diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index e49fd09be0..aca5023f97 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1233,6 +1233,52 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) } } +void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"InstanceNormalizationQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 1); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + + if (inputTensorInfo.GetNumDimensions() > 4) + { + throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported."); + } + + ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + + // Check the supported data types + std::vector<DataType> supportedTypes = + { + DataType::Float32, + DataType::Float16 + }; + + ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); + ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); + + ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + + ValidatePointer(m_Beta, descriptorName, "beta"); + ValidatePointer(m_Eps, descriptorName, "epsilon"); + ValidatePointer(m_Gamma, descriptorName, "gamma"); + + const TensorInfo& beta = m_Beta->GetTensorInfo(); + const TensorInfo& epsilon = m_Eps->GetTensorInfo(); + const TensorInfo& gamma = m_Gamma->GetTensorInfo(); + + ValidateTensorNumDimensions(beta, descriptorName, 1, "beta"); + ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon"); + ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma"); + + ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta"); + ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon"); + ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma"); +} + void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { const std::string descriptorName{"L2NormalizationQueueDescriptor"}; |