aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp30
1 files changed, 8 insertions, 22 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 1c18551679..3f5972dab6 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -36,7 +36,11 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ
return info;
}
- return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
+ return TensorInfo(info.GetShape(),
+ type.value(),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset(),
+ info.IsConstant());
}
} // anonymous namespace
@@ -364,16 +368,7 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
TensorInfo weightsInfo;
const TensorInfo* weightsInfoPtr = nullptr;
- if (descriptor.m_ConstantWeights)
- {
- ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
- weightsInfo = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
- }
- else
- {
- weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
-
- }
+ weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
weightsInfoPtr = &weightsInfo;
TensorInfo biasInfo;
@@ -385,17 +380,8 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
if (descriptor.m_BiasEnabled)
{
- if(descriptor.m_ConstantWeights)
- {
- ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
- biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
- biasInfoPtr = &biasInfo;
- }
- else
- {
- biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
- biasInfoPtr = &biasInfo;
- }
+ biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
+ biasInfoPtr = &biasInfo;
}
else
{