aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2019-10-02 14:07:47 +0100
committerKevin May <kevin.may@arm.com>2019-10-03 11:56:18 +0000
commitce5045a00485f8a8c35814c0781ccbcca5678e5c (patch)
tree7481fbdfd859f3edd24c1bf99830a0c89d6bb9ab /src/backends/backendsCommon/WorkloadData.cpp
parentd47a064ab4c38559c6be931cb1771feb6e026ea4 (diff)
downloadarmnn-ce5045a00485f8a8c35814c0781ccbcca5678e5c.tar.gz
IVGCVSW-3932 Add frontend for INSTANCE_NORMALIZATION
Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: Ib152148ccd8d2733c617d0cf9402661fc6b71316
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp46
1 files changed, 46 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index e49fd09be0..aca5023f97 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1233,6 +1233,52 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
}
}
+void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ if (inputTensorInfo.GetNumDimensions() > 4)
+ {
+ throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
+ }
+
+ ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ // Check the supported data types
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16
+ };
+
+ ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ ValidatePointer(m_Beta, descriptorName, "beta");
+ ValidatePointer(m_Eps, descriptorName, "epsilon");
+ ValidatePointer(m_Gamma, descriptorName, "gamma");
+
+ const TensorInfo& beta = m_Beta->GetTensorInfo();
+ const TensorInfo& epsilon = m_Eps->GetTensorInfo();
+ const TensorInfo& gamma = m_Gamma->GetTensorInfo();
+
+ ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
+ ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon");
+ ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta");
+ ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon");
+ ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma");
+}
+
void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"L2NormalizationQueueDescriptor"};