aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl')
-rw-r--r--src/backends/cl/ClBackend.cpp9
-rw-r--r--src/backends/cl/workloads/ClFullyConnectedWorkload.cpp5
-rw-r--r--src/backends/cl/workloads/ClFullyConnectedWorkload.hpp2
3 files changed, 12 insertions, 4 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index 0fc5da78d1..018adec781 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -399,11 +399,18 @@ OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
{
FullyConnectedLayer* baseLayer = PolymorphicDowncast<FullyConnectedLayer*>(&base);
+ Optional<TensorInfo> biases;
+
+ if (baseLayer->GetParameters().m_BiasEnabled)
+ {
+ biases = baseLayer->m_Bias->GetTensorInfo();
+ }
+
arm_compute::Status status = ClFullyConnectedWorkloadValidate(
baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
baseLayer->m_Weight->GetTensorInfo(),
- baseLayer->m_Bias->GetTensorInfo(),
+ biases,
baseLayer->GetParameters(),
&activationDesc);
diff --git a/src/backends/cl/workloads/ClFullyConnectedWorkload.cpp b/src/backends/cl/workloads/ClFullyConnectedWorkload.cpp
index 3eb53e64b4..017f4fff6b 100644
--- a/src/backends/cl/workloads/ClFullyConnectedWorkload.cpp
+++ b/src/backends/cl/workloads/ClFullyConnectedWorkload.cpp
@@ -19,7 +19,7 @@ using namespace armcomputetensorutils;
arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& weights,
- const TensorInfo& biases,
+ const Optional<TensorInfo>& biases,
const FullyConnectedDescriptor& descriptor,
const ActivationDescriptor* activationDescriptor)
{
@@ -31,7 +31,8 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
arm_compute::TensorInfo* optionalAclBiases = nullptr;
if (descriptor.m_BiasEnabled)
{
- aclBiases = BuildArmComputeTensorInfo(biases);
+ ARMNN_ASSERT(biases.has_value());
+ aclBiases = BuildArmComputeTensorInfo(biases.value());
optionalAclBiases = &aclBiases;
}
diff --git a/src/backends/cl/workloads/ClFullyConnectedWorkload.hpp b/src/backends/cl/workloads/ClFullyConnectedWorkload.hpp
index 210757779f..3ab9f986a8 100644
--- a/src/backends/cl/workloads/ClFullyConnectedWorkload.hpp
+++ b/src/backends/cl/workloads/ClFullyConnectedWorkload.hpp
@@ -18,7 +18,7 @@ namespace armnn
arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& weights,
- const TensorInfo& biases,
+ const Optional<TensorInfo>& biases,
const FullyConnectedDescriptor& descriptor,
const ActivationDescriptor* activationDescriptor = nullptr);