diff options
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r-- | reference_model/src/operators.cc | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index f3023dd..1ae0683 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -324,6 +324,8 @@ extern "C" } tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input, + tosa_tensor_t client_weight, + tosa_tensor_t client_bias, const int32_t client_input_zp, const int32_t client_weight_zp, tosa_tensor_t client_output) @@ -335,21 +337,26 @@ extern "C" // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); + tosa::TosaSerializationTensor* weight = translate_client_tensor(client_weight, "weight"); + tosa::TosaSerializationTensor* bias = translate_client_tensor(client_bias, "bias"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_FULLY_CONNECTED, - tosa::Attribute::Attribute_FullyConnectedAttribute, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator( + tosa::Op::Op_FULLY_CONNECTED, tosa::Attribute::Attribute_FullyConnectedAttribute, &attr, + { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, output }, - { input->GetName() }, { output->GetName() }); + tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, weight, bias, output }, + { input->GetName(), weight->GetName(), bias->GetName() }, + { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size)); + TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size)); + TOSA_RETURN_ON_ERROR(runner.setInput(bias->GetName(), client_bias.data, client_bias.size)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); |