diff options
author | Kevin May <kevin.may@arm.com> | 2019-10-02 14:07:47 +0100 |
---|---|---|
committer | Kevin May <kevin.may@arm.com> | 2019-10-03 11:56:18 +0000 |
commit | ce5045a00485f8a8c35814c0781ccbcca5678e5c (patch) | |
tree | 7481fbdfd859f3edd24c1bf99830a0c89d6bb9ab /src/backends/backendsCommon/WorkloadData.cpp | |
parent | d47a064ab4c38559c6be931cb1771feb6e026ea4 (diff) | |
download | armnn-ce5045a00485f8a8c35814c0781ccbcca5678e5c.tar.gz |
IVGCVSW-3932 Add frontend for INSTANCE_NORMALIZATION
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: Ib152148ccd8d2733c617d0cf9402661fc6b71316
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index e49fd09be0..aca5023f97 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1233,6 +1233,52 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) } } +void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"InstanceNormalizationQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 1); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + + if (inputTensorInfo.GetNumDimensions() > 4) + { + throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported."); + } + + ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + + // Check the supported data types + std::vector<DataType> supportedTypes = + { + DataType::Float32, + DataType::Float16 + }; + + ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); + ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); + + ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + + ValidatePointer(m_Beta, descriptorName, "beta"); + ValidatePointer(m_Eps, descriptorName, "epsilon"); + ValidatePointer(m_Gamma, descriptorName, "gamma"); + + const TensorInfo& beta = m_Beta->GetTensorInfo(); + const TensorInfo& epsilon = m_Eps->GetTensorInfo(); + const TensorInfo& gamma = m_Gamma->GetTensorInfo(); + + ValidateTensorNumDimensions(beta, descriptorName, 1, "beta"); + ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon"); + ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma"); + + ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta"); + ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon"); + ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma"); +} + void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { const std::string descriptorName{"L2NormalizationQueueDescriptor"}; |