aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Bentham <Matthew.Bentham@arm.com>2018-09-21 10:29:58 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:56 +0100
commit785df505a621a5b98084246056d80090073b950b (patch)
tree506f3de8a1c692ec90aaf81c88f593725579f8b2
parent10b4dfd8e9ccd7a03df7bb053ee1c644cb37f8ab (diff)
downloadarmnn-785df505a621a5b98084246056d80090073b950b.tar.gz
IVGCVSW-949 Simplify use of InitialiseArmComputeClTensorData
Change-Id: I556881e34f26e8152feaaba06d99828394872f58
-rw-r--r--src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp8
-rw-r--r--src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp4
-rw-r--r--src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp4
-rw-r--r--src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp4
-rw-r--r--src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp4
-rw-r--r--src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp19
-rw-r--r--src/backends/ClWorkloads/ClLstmFloatWorkload.cpp51
-rw-r--r--src/backends/ClWorkloads/ClWorkloadUtils.hpp12
8 files changed, 40 insertions, 66 deletions
diff --git a/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp b/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
index 021734aaa6..d05349b819 100644
--- a/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp
@@ -68,10 +68,10 @@ ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
m_Gamma.get(),
m_Data.m_Parameters.m_Eps);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Mean, m_Data.m_Mean);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Variance, m_Data.m_Variance);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Beta, m_Data.m_Beta);
- InitializeArmComputeClTensorDataForFloatTypes(*m_Gamma, m_Data.m_Gamma);
+ InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
+ InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
+ InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
+ InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
// Force Compute Library to perform the necessary copying and reshaping, after which
// delete all the input tensors that will no longer be needed
diff --git a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
index 029f41d5dc..f0b9a46d60 100644
--- a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
@@ -52,11 +52,11 @@ ClConvolution2dFloatWorkload::ClConvolution2dFloatWorkload(const Convolution2dQu
&output,
padStrideInfo);
- InitializeArmComputeClTensorDataForFloatTypes(*m_KernelTensor, m_Data.m_Weight);
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasTensor, m_Data.m_Bias);
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
index e6783b698a..c9f5eaa31d 100644
--- a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
@@ -51,11 +51,11 @@ ClConvolution2dUint8Workload::ClConvolution2dUint8Workload(const Convolution2dQu
&output,
padStrideInfo);
- InitialiseArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitialiseArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp b/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
index 635ae1f327..bc3b165490 100644
--- a/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClDepthwiseConvolutionFloatWorkload.cpp
@@ -17,11 +17,11 @@ ClDepthwiseConvolutionFloatWorkload::ClDepthwiseConvolutionFloatWorkload(
const WorkloadInfo& info)
: ClDepthwiseConvolutionBaseWorkload(descriptor, info)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_KernelTensor, m_Data.m_Weight);
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasTensor, m_Data.m_Bias);
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
m_DepthwiseConvolutionLayer->prepare();
diff --git a/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp b/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
index af5836e908..4ea5590486 100644
--- a/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
+++ b/src/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp
@@ -17,11 +17,11 @@ ClDepthwiseConvolutionUint8Workload::ClDepthwiseConvolutionUint8Workload(
const WorkloadInfo& info)
: ClDepthwiseConvolutionBaseWorkload(descriptor, info)
{
- InitialiseArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight->template GetConstTensor<uint8_t>());
+ InitializeArmComputeClTensorData(*m_KernelTensor, m_Data.m_Weight);
if (m_BiasTensor)
{
- InitialiseArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias->template GetConstTensor<int32_t>());
+ InitializeArmComputeClTensorData(*m_BiasTensor, m_Data.m_Bias);
}
m_DepthwiseConvolutionLayer->prepare();
diff --git a/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp b/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
index 8d2fd0e909..4686d1c8ee 100644
--- a/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
+++ b/src/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
@@ -68,26 +68,11 @@ ClFullyConnectedWorkload::ClFullyConnectedWorkload(const FullyConnectedQueueDesc
fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
- // Allocate
- if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
- {
- InitialiseArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
- }
- else
- {
- InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
- }
+ InitializeArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight);
if (m_BiasesTensor)
{
- if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
- {
- InitialiseArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
- }
- else
- {
- InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
- }
+ InitializeArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
index 09a34c2d02..8e2c875bab 100644
--- a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
@@ -172,57 +172,40 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor,
armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
- InitialiseArmComputeClTensorData(*m_InputToForgetWeightsTensor,
- m_Data.m_InputToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_InputToCellWeightsTensor,
- m_Data.m_InputToCellWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_InputToOutputWeightsTensor,
- m_Data.m_InputToOutputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor,
- m_Data.m_RecurrentToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToCellWeightsTensor,
- m_Data.m_RecurrentToCellWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor,
- m_Data.m_RecurrentToOutputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_ForgetGateBiasTensor,
- m_Data.m_ForgetGateBias->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_CellBiasTensor,
- m_Data.m_CellBias->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_OutputGateBiasTensor,
- m_Data.m_OutputGateBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
+ InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
+ InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
+ InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
+ InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
+ InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
if (!m_Data.m_Parameters.m_CifgEnabled)
{
- InitialiseArmComputeClTensorData(*m_InputToInputWeightsTensor,
- m_Data.m_InputToInputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToInputWeightsTensor,
- m_Data.m_RecurrentToInputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
if (m_Data.m_CellToInputWeights != nullptr)
{
- InitialiseArmComputeClTensorData(*m_CellToInputWeightsTensor,
- m_Data.m_CellToInputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
}
- InitialiseArmComputeClTensorData(*m_InputGateBiasTensor,
- m_Data.m_InputGateBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
}
if (m_Data.m_Parameters.m_ProjectionEnabled)
{
- InitialiseArmComputeClTensorData(*m_ProjectionWeightsTensor,
- m_Data.m_ProjectionWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
if (m_Data.m_ProjectionBias != nullptr)
{
- InitialiseArmComputeClTensorData(*m_ProjectionBiasTensor,
- m_Data.m_ProjectionBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
}
}
if (m_Data.m_Parameters.m_PeepholeEnabled)
{
- InitialiseArmComputeClTensorData(*m_CellToForgetWeightsTensor,
- m_Data.m_CellToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_CellToOutputWeightsTensor,
- m_Data.m_CellToOutputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
+ InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
}
// Force Compute Library to perform the necessary copying and reshaping, after which
diff --git a/src/backends/ClWorkloads/ClWorkloadUtils.hpp b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
index 6f1b155745..a10237cf40 100644
--- a/src/backends/ClWorkloads/ClWorkloadUtils.hpp
+++ b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
@@ -42,8 +42,8 @@ void InitialiseArmComputeClTensorData(arm_compute::CLTensor& clTensor, const T*
CopyArmComputeClTensorData<T>(data, clTensor);
}
-inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& clTensor,
- const ConstCpuTensorHandle *handle)
+inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor,
+ const ConstCpuTensorHandle* handle)
{
BOOST_ASSERT(handle);
switch(handle->GetTensorInfo().GetDataType())
@@ -54,8 +54,14 @@ inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor&
case DataType::Float32:
InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
break;
+ case DataType::QuantisedAsymm8:
+ InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
+ break;
+ case DataType::Signed32:
+ InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
+ break;
default:
- BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
+ BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
}
};