aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp46
1 files changed, 46 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index e49fd09be0..aca5023f97 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1233,6 +1233,52 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
}
}
+void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ if (inputTensorInfo.GetNumDimensions() > 4)
+ {
+ throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
+ }
+
+ ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ // Check the supported data types
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16
+ };
+
+ ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ ValidatePointer(m_Beta, descriptorName, "beta");
+ ValidatePointer(m_Eps, descriptorName, "epsilon");
+ ValidatePointer(m_Gamma, descriptorName, "gamma");
+
+ const TensorInfo& beta = m_Beta->GetTensorInfo();
+ const TensorInfo& epsilon = m_Eps->GetTensorInfo();
+ const TensorInfo& gamma = m_Gamma->GetTensorInfo();
+
+ ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
+ ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon");
+ ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta");
+ ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon");
+ ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma");
+}
+
void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"L2NormalizationQueueDescriptor"};