From eff363d58992fb6384053259f9e1ee773f8cd4df Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Fri, 5 Apr 2019 15:25:46 +0100 Subject: IVGCVSW-2914 Add Switch Layer and no-op factory method Change-Id: I6a6ece708a49e8a97c83a3e7fec11c88af1e1cfa Signed-off-by: Sadik Armagan Signed-off-by: Aron Virginas-Tar --- src/backends/backendsCommon/LayerSupportBase.cpp | 9 + src/backends/backendsCommon/LayerSupportBase.hpp | 6 + src/backends/backendsCommon/WorkloadData.cpp | 236 ++++++++++++--------- src/backends/backendsCommon/WorkloadData.hpp | 5 + src/backends/backendsCommon/WorkloadFactory.cpp | 19 ++ src/backends/backendsCommon/WorkloadFactory.hpp | 3 + .../test/IsLayerSupportedTestImpl.hpp | 2 + 7 files changed, 175 insertions(+), 105 deletions(-) (limited to 'src/backends') diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index fc2d502fbd..6cad7b93ab 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -397,4 +397,13 @@ bool LayerSupportBase::IsSubtractionSupported(const TensorInfo& input0, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsSwitchSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output0, + const TensorInfo& output1, + Optional reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + } // namespace armnn diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 7c38b67379..3c39f8919d 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -246,6 +246,12 @@ public: const TensorInfo& input1, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + + bool IsSwitchSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output0, + const TensorInfo& output1, + Optional reasonIfUnsupported = EmptyOptional()) const override; }; } // namespace armnn diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 348c864863..b850a65acf 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -75,45 +75,23 @@ void ValidateTensorShapesMatch(const TensorInfo& first, } //--------------------------------------------------------------- -void ValidateNoInputs(const WorkloadInfo& workloadInfo, std::string const& descName) +void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize) { - if (workloadInfo.m_InputTensorInfos.size() != 0) + if (workloadInfo.m_InputTensorInfos.size() != expectedSize) { throw InvalidArgumentException(descName + - ": Requires no inputs. " + - to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided."); - } -} - -//--------------------------------------------------------------- -void ValidateSingleInput(const WorkloadInfo& workloadInfo, std::string const& descName) -{ - if (workloadInfo.m_InputTensorInfos.size() != 1) - { - throw InvalidArgumentException(descName + - ": Requires exactly one input. " + - to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided." ); - } -} - -//--------------------------------------------------------------- -void ValidateTwoInputs(const WorkloadInfo& workloadInfo, std::string const& descName) -{ - if (workloadInfo.m_InputTensorInfos.size() != 2) - { - throw InvalidArgumentException(descName + - ": Requires exactly two workloadInfo.m_InputTensorInfos. " + + ": Requires exactly " + to_string(expectedSize) + "input(s). " + to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided."); } } //--------------------------------------------------------------- -void ValidateSingleOutput(const WorkloadInfo& workloadInfo, std::string const& descName) +void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize) { - if (workloadInfo.m_OutputTensorInfos.size() != 1) + if (workloadInfo.m_OutputTensorInfos.size() != expectedSize) { throw InvalidArgumentException(descName + - ": Requires exactly one output. " + + ": Requires exactly " + to_string(expectedSize) + " output(s). " + to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided."); } } @@ -242,6 +220,18 @@ void ValidateTensorQuantizationMultiplier(const TensorInfo& inputTensor1, const } } +//--------------------------------------------------------------- +void ValidateDataTypes(const TensorInfo& info, + const std::vector& supportedTypes, + std::string const& descName) +{ + auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType()); + if (iterator == supportedTypes.end()) + { + throw InvalidArgumentException(descName + ": " + " Tensor type is not supported."); + } +} + } //namespace void QueueDescriptor::ValidateInputsOutputs(const std::string& descName, @@ -254,8 +244,8 @@ void QueueDescriptor::ValidateInputsOutputs(const std::string& descName, //--------------------------------------------------------------- void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "MemCopyQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MemCopyQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MemCopyQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemCopyQueueDescriptor" , 1); if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size()) { @@ -299,8 +289,8 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ActivationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ActivationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ActivationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ActivationQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "ActivationQueueDescriptor", @@ -311,8 +301,8 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "SoftmaxQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "SoftmaxQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SoftmaxQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "SoftmaxQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], @@ -324,7 +314,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "SplitterQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SplitterQueueDescriptor", 1); if (workloadInfo.m_OutputTensorInfos.size() <= 0) { @@ -372,7 +362,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleOutput(workloadInfo, "MergerQueueDescriptor"); + ValidateNumOutputs(workloadInfo, "MergerQueueDescriptor", 1); if (m_Inputs.size() <= 0) { @@ -444,8 +434,8 @@ void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "FullyConnectedQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "FullyConnectedQueueDescriptor"); + ValidateNumInputs(workloadInfo, "FullyConnectedQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "FullyConnectedQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FullyConnectedQueueDescriptor", 2, "output"); if (!(workloadInfo.m_InputTensorInfos[0].GetNumDimensions() == 2 || @@ -487,8 +477,8 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c //--------------------------------------------------------------- void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "NormalizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "NormalizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "NormalizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "NormalizationQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "NormalizationQueueDescriptor", @@ -498,8 +488,8 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "AdditionQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "AdditionQueueDescriptor"); + ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -513,8 +503,8 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const //--------------------------------------------------------------- void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MultiplicationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MultiplicationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -526,8 +516,8 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "BatchNormalizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "BatchNormalizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "BatchNormalizationQueueDescriptor", @@ -554,8 +544,8 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "Convolution2dQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "Convolution2dQueueDescriptor"); + ValidateNumInputs(workloadInfo, "Convolution2dQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "Convolution2dQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "output"); @@ -580,8 +570,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1); ValidateTensorNumDimensions( workloadInfo.m_InputTensorInfos[0], "DepthwiseConvolution2dQueueDescriptor", 4, "input"); @@ -625,8 +615,8 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "PermuteQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "PermuteQueueDescriptor"); + ValidateNumInputs(workloadInfo, "PermuteQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "PermuteQueueDescriptor", 1); const PermutationVector& mapping = m_Parameters.m_DimMappings; @@ -650,8 +640,8 @@ void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "Pooling2dQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "Pooling2dQueueDescriptor"); + ValidateNumInputs(workloadInfo, "Pooling2dQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "Pooling2dQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "output"); @@ -659,8 +649,8 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ResizeBilinearQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ResizeBilinearQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "output"); @@ -694,8 +684,8 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "FakeQuantizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "FakeQuantizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "output"); @@ -713,8 +703,8 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "L2NormalizationQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "L2NormalizationQueueDescriptor"); + ValidateNumInputs(workloadInfo, "L2NormalizationQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "L2NormalizationQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "output"); @@ -727,8 +717,8 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateNoInputs(workloadInfo, "ConstantQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ConstantQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ConstantQueueDescriptor", 0); + ValidateNumOutputs(workloadInfo, "ConstantQueueDescriptor", 1); if (!m_LayerOutput) { @@ -744,8 +734,8 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ReshapeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ReshapeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "ReshapeQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ReshapeQueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements()) { @@ -757,8 +747,8 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "SpaceToBatchNdQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "SpaceToBatchNdQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1); ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output"); @@ -804,8 +794,8 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "FloorQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "FlootQueueDescriptor"); + ValidateNumInputs(workloadInfo, "FloorQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "FlootQueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0] != workloadInfo.m_OutputTensorInfos[0]) { @@ -821,8 +811,8 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor"); + ValidateNumInputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32) { @@ -843,8 +833,8 @@ void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor"); - ValidateSingleOutput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor"); + ValidateNumInputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float16) { @@ -864,8 +854,8 @@ void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "DivisionQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DivisionQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -877,8 +867,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "SubtractionQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "SubtractionQueueDescriptor"); + ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -890,8 +880,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MaximumQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MaximumQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -903,8 +893,8 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "MeanQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MeanQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MeanQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; @@ -929,8 +919,8 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "PadQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "PadQueueDescriptor"); + ValidateNumInputs(workloadInfo, "PadQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "PadQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; @@ -948,8 +938,8 @@ void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "QuantizeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "QuantizeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "QuantizeQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "QuantizeQueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32) @@ -966,14 +956,14 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "BatchToSpaceNdQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "BatchToSpaceNdQueueDescriptor"); + ValidateNumInputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1); } void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "StridedSliceQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "StridedSliceQueueDescriptor"); + ValidateNumInputs(workloadInfo, "StridedSliceQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1); const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const uint32_t rank = input.GetNumDimensions(); @@ -1015,8 +1005,8 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MinimumQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MinimumQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1028,14 +1018,14 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "DebugQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DebugQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DebugQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "DebugQueueDescriptor", 1); } void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "EqualQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "EqualQueueDescriptor"); + ValidateNumInputs(workloadInfo, "EqualQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "EqualQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1052,8 +1042,8 @@ void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "GreaterQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "GreaterQueueDescriptor"); + ValidateNumInputs(workloadInfo, "GreaterQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "GreaterQueueDescriptor", 1); ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1070,8 +1060,8 @@ void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "RsqrtQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "RsqrtQueueDescriptor"); + ValidateNumInputs(workloadInfo, "RsqrtQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "RsqrtQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_OutputTensorInfos[0], "RsqrtQueueDescriptor", @@ -1081,8 +1071,8 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "GatherQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "GatherQueueDescriptor"); + ValidateNumInputs(workloadInfo, "GatherQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "GatherQueueDescriptor", 1); const TensorInfo& indices = workloadInfo.m_InputTensorInfos[1]; @@ -1102,7 +1092,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2); if (workloadInfo.m_OutputTensorInfos.size() != 4) { @@ -1155,8 +1145,8 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateSingleInput(workloadInfo, "DequantizeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "DequantizeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "DequantizeQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "DequantizeQueueDescriptor", 1); if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedAsymm8 && workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedSymm16) @@ -1172,8 +1162,8 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor"); - ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor"); + ValidateNumInputs(workloadInfo, "MergeQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "MergeQueueDescriptor", 1); ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], @@ -1192,6 +1182,42 @@ void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output"); } +void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "SwitchQueueDescriptor", 2); + ValidateNumOutputs(workloadInfo, "SwitchQueueDescriptor", 2); + + std::vector supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "SwitchQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "SwitchQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "SwitchQueueDescriptor"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[0], + "SwitchQueueDescriptor", + "input0", + "output0"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[1], + "SwitchQueueDescriptor", + "input0", + "output1"); +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 1bf735288d..1b5f86dde7 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -426,4 +426,9 @@ struct MergeQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; +struct SwitchQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 4ea3ea9f9b..d9774b063d 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -729,6 +729,19 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::Switch: + { + const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo(); + const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo(); + result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType), + OverrideDataType(input1, dataType), + OverrideDataType(output0, dataType), + OverrideDataType(output1, dataType), + reason); + break; + } case LayerType::Mean: { auto cLayer = boost::polymorphic_downcast(&layer); @@ -1041,4 +1054,10 @@ std::unique_ptr IWorkloadFactory::CreateSubtraction(const Subtraction return std::unique_ptr(); } +std::unique_ptr IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr(); +} + } diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index 889bc9d595..5c07b3af6f 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -177,6 +177,9 @@ public: virtual std::unique_ptr CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateSwitch(const SwitchQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; }; } //namespace armnn diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 0588607a82..a7d7b094cf 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -402,6 +402,8 @@ DECLARE_LAYER_POLICY_2_PARAM(StridedSlice) DECLARE_LAYER_POLICY_1_PARAM(Subtraction) +DECLARE_LAYER_POLICY_1_PARAM(Switch) + // Generic implementation to get the number of input slots for a given layer type; template -- cgit v1.2.1