ArmNN
 21.08
BatchNormalizationQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for BatchNormalizationQueueDescriptor:
BatchNormalizationQueueDescriptor → QueueDescriptorWithParameters< BatchNormalizationDescriptor > → QueueDescriptor

Public Member Functions

 BatchNormalizationQueueDescriptor ()
 
void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptor
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Public Attributes

const ConstTensorHandle * m_Mean
 
const ConstTensorHandle * m_Variance
 
const ConstTensorHandle * m_Beta
 
const ConstTensorHandle * m_Gamma
 
- Public Attributes inherited from QueueDescriptorWithParameters< BatchNormalizationDescriptor >
BatchNormalizationDescriptor m_Parameters
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 

Additional Inherited Members

- Protected Member Functions inherited from QueueDescriptorWithParameters< BatchNormalizationDescriptor >
 ~QueueDescriptorWithParameters ()=default
 
 QueueDescriptorWithParameters ()=default
 
 QueueDescriptorWithParameters (QueueDescriptorWithParameters const &)=default
 
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
 
- Protected Member Functions inherited from QueueDescriptor
 ~QueueDescriptor ()=default
 
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 310 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ BatchNormalizationQueueDescriptor()

Definition at line 312 of file WorkloadData.hpp.

312  BatchNormalizationQueueDescriptor()
313  : m_Mean(nullptr)
314  , m_Variance(nullptr)
315  , m_Beta(nullptr)
316  , m_Gamma(nullptr)
317  {
318  }
const ConstTensorHandle * m_Variance

Member Function Documentation

◆ Validate()

void Validate (const WorkloadInfo & workloadInfo) const

Definition at line 1205 of file WorkloadData.cpp.

References armnn::BFloat16, armnn::Float16, armnn::Float32, WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, armnn::QAsymmS8, armnn::QAsymmU8, and armnn::QSymmS16.

1206 {
1207  const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
1208 
1209  ValidateNumInputs(workloadInfo, descriptorName, 1);
1210  ValidateNumOutputs(workloadInfo, descriptorName, 1);
1211 
1212  const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1213  const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1214 
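1215  std::vector<DataType> supportedTypes =
1216  {
1217      DataType::BFloat16,
1218      DataType::Float16,
1219      DataType::Float32,
1220      DataType::QAsymmS8,
1221      DataType::QAsymmU8,
1222      DataType::QSymmS16
1223  };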
1224 
1225  ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1226  ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1227 
1228  ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1229  ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1230 
1231  ValidatePointer(m_Mean, descriptorName, "mean");
1232  ValidatePointer(m_Variance, descriptorName, "variance");
1233  ValidatePointer(m_Beta, descriptorName, "beta");
1234  ValidatePointer(m_Gamma, descriptorName, "gamma");
1235 
1236  const TensorInfo& mean = m_Mean->GetTensorInfo();
1237  const TensorInfo& variance = m_Variance->GetTensorInfo();
1238  const TensorInfo& beta = m_Beta->GetTensorInfo();
1239  const TensorInfo& gamma = m_Gamma->GetTensorInfo();
1240 
1241  ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1242  ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1243  ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1244  ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
1245 
1246  ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1247  ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1248  ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
1249 }
const ConstTensorHandle * m_Variance
const TensorInfo & GetTensorInfo() const
std::vector< TensorInfo > m_InputTensorInfos
std::vector< TensorInfo > m_OutputTensorInfos

Member Data Documentation

◆ m_Beta

const ConstTensorHandle* m_Beta

◆ m_Gamma

const ConstTensorHandle* m_Gamma

◆ m_Mean

const ConstTensorHandle* m_Mean

◆ m_Variance

const ConstTensorHandle* m_Variance

The documentation for this struct was generated from the following files: