aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJiacheng Liang <jiacheng.liang@arm.com>2023-07-27 16:50:15 +0100
committerEric Kunze <eric.kunze@arm.com>2023-07-31 15:38:01 +0000
commited1a15b641128e38e6efc9128e7cff97b782446d (patch)
treeb4faaa1b57b4c07fc7b34fba44577d6fe4f26850
parente5cabbf7528849aac35b498ce0711a144c1a08d5 (diff)
downloadreference_model-ed1a15b641128e38e6efc9128e7cff97b782446d.tar.gz
Fixed missing tensor inputs in fully_connected in model runner
Signed-off-by: Jiacheng Liang <jiacheng.liang@arm.com> Change-Id: I473adc1525319b5574ee0e36d10a530277d9215d
-rw-r--r--reference_model/include/operators.h2
-rw-r--r--reference_model/src/operators.cc17
2 files changed, 14 insertions, 5 deletions
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index b12604f..6b4ec58 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -109,6 +109,8 @@ extern "C"
tosa_tensor_t client_output);
tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input,
+ tosa_tensor_t client_weight,
+ tosa_tensor_t client_bias,
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output);
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index f3023dd..1ae0683 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -324,6 +324,8 @@ extern "C"
}
tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input,
+ tosa_tensor_t client_weight,
+ tosa_tensor_t client_bias,
const int32_t client_input_zp,
const int32_t client_weight_zp,
tosa_tensor_t client_output)
@@ -335,21 +337,26 @@ extern "C"
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
+ tosa::TosaSerializationTensor* weight = translate_client_tensor(client_weight, "weight");
+ tosa::TosaSerializationTensor* bias = translate_client_tensor(client_bias, "bias");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_FULLY_CONNECTED,
- tosa::Attribute::Attribute_FullyConnectedAttribute, &attr,
- { input->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(
+ tosa::Op::Op_FULLY_CONNECTED, tosa::Attribute::Attribute_FullyConnectedAttribute, &attr,
+ { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, output },
- { input->GetName() }, { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, weight, bias, output },
+ { input->GetName(), weight->GetName(), bias->GetName() },
+ { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner;
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
+ TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size));
+ TOSA_RETURN_ON_ERROR(runner.setInput(bias->GetName(), client_bias.data, client_bias.size));
// Execute
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());