aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEFullyConnectedLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NEFullyConnectedLayer.cpp13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 77028d96a2..4f858fb54b 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,7 +61,7 @@ NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> mem
}
void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
- FullyConnectedLayerInfo fc_info)
+ FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info)
{
// Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
@@ -76,7 +76,7 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
_impl->original_weights = weights;
_impl->is_prepared = false;
- _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), fc_info);
+ _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), fc_info, weights_info);
if(_impl->weights_manager != nullptr)
{
@@ -88,6 +88,13 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
_impl->workspace = manage_workspace<Tensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->run_pack);
}
+Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights,
+ const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info,
+ const WeightsInfo &weights_info)
+{
+ return cpu::CpuFullyConnected::has_opt_impl(expected_weight_format, input, weights, biases, output, fc_info, weights_info);
+}
+
Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
FullyConnectedLayerInfo fc_info)
{