aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonConvolution2dBaseWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonConvolution2dBaseWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonConvolution2dBaseWorkload.cpp25
1 files changed, 1 insertions, 24 deletions
diff --git a/src/backends/neon/workloads/NeonConvolution2dBaseWorkload.cpp b/src/backends/neon/workloads/NeonConvolution2dBaseWorkload.cpp
index 8da3e47249..b11d10fd2f 100644
--- a/src/backends/neon/workloads/NeonConvolution2dBaseWorkload.cpp
+++ b/src/backends/neon/workloads/NeonConvolution2dBaseWorkload.cpp
@@ -109,30 +109,8 @@ NeonConvolution2dBaseWorkload<dataTypes...>::NeonConvolution2dBaseWorkload(
}
BOOST_ASSERT(m_ConvolutionLayer);
- armnn::DataType dataType = m_Data.m_Weight->GetTensorInfo().GetDataType();
+ InitializeArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight);
- switch (dataType)
- {
- case DataType::Float16:
- {
- InitialiseArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<Half>());
- break;
- }
- case DataType::Float32:
- {
- InitialiseArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<float>());
- break;
- }
- case DataType::QuantisedAsymm8:
- {
- InitialiseArmComputeTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<uint8_t>());
- break;
- }
- default:
- {
- BOOST_ASSERT_MSG(false, "Unknown DataType.");
- }
- }
}
template<armnn::DataType... dataTypes>
@@ -147,4 +125,3 @@ template class NeonConvolution2dBaseWorkload<armnn::DataType::Float16, armnn::Da
template class NeonConvolution2dBaseWorkload<armnn::DataType::QuantisedAsymm8>;
} //namespace armnn
-