aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2019-04-05 15:25:46 +0100
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-04-05 17:11:02 +0100
commiteff363d58992fb6384053259f9e1ee773f8cd4df (patch)
treee0bce8c4694ee15e016951f9168afbf9b75a9c79 /src/backends/backendsCommon/WorkloadData.cpp
parent1f88630874fe346cd0cca8d8e38e0fb96cc1a3f4 (diff)
downloadarmnn-eff363d58992fb6384053259f9e1ee773f8cd4df.tar.gz
IVGCVSW-2914 Add Switch Layer and no-op factory method
Change-Id: I6a6ece708a49e8a97c83a3e7fec11c88af1e1cfa Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp236
1 files changed, 131 insertions, 105 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 348c864863..b850a65acf 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -75,45 +75,23 @@ void ValidateTensorShapesMatch(const TensorInfo& first,
}
//---------------------------------------------------------------
-void ValidateNoInputs(const WorkloadInfo& workloadInfo, std::string const& descName)
+void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
{
- if (workloadInfo.m_InputTensorInfos.size() != 0)
+ if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
{
throw InvalidArgumentException(descName +
- ": Requires no inputs. " +
- to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided.");
- }
-}
-
-//---------------------------------------------------------------
-void ValidateSingleInput(const WorkloadInfo& workloadInfo, std::string const& descName)
-{
- if (workloadInfo.m_InputTensorInfos.size() != 1)
- {
- throw InvalidArgumentException(descName +
- ": Requires exactly one input. " +
- to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided." );
- }
-}
-
-//---------------------------------------------------------------
-void ValidateTwoInputs(const WorkloadInfo& workloadInfo, std::string const& descName)
-{
- if (workloadInfo.m_InputTensorInfos.size() != 2)
- {
- throw InvalidArgumentException(descName +
- ": Requires exactly two workloadInfo.m_InputTensorInfos. " +
+ ": Requires exactly " + to_string(expectedSize) + "input(s). " +
to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
}
}
//---------------------------------------------------------------
-void ValidateSingleOutput(const WorkloadInfo& workloadInfo, std::string const& descName)
+void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
{
- if (workloadInfo.m_OutputTensorInfos.size() != 1)
+ if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
{
throw InvalidArgumentException(descName +
- ": Requires exactly one output. " +
+ ": Requires exactly " + to_string(expectedSize) + " output(s). " +
to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
}
}
@@ -242,6 +220,18 @@ void ValidateTensorQuantizationMultiplier(const TensorInfo& inputTensor1, const
}
}
+//---------------------------------------------------------------
+void ValidateDataTypes(const TensorInfo& info,
+ const std::vector<armnn::DataType>& supportedTypes,
+ std::string const& descName)
+{
+ auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
+ if (iterator == supportedTypes.end())
+ {
+ throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
+ }
+}
+
} //namespace
void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
@@ -254,8 +244,8 @@ void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
//---------------------------------------------------------------
void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "MemCopyQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MemCopyQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MemCopyQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "MemCopyQueueDescriptor" , 1);
if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
{
@@ -299,8 +289,8 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ActivationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ActivationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ActivationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ActivationQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"ActivationQueueDescriptor",
@@ -311,8 +301,8 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "SoftmaxQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "SoftmaxQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SoftmaxQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "SoftmaxQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
@@ -324,7 +314,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "SplitterQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SplitterQueueDescriptor", 1);
if (workloadInfo.m_OutputTensorInfos.size() <= 0)
{
@@ -372,7 +362,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleOutput(workloadInfo, "MergerQueueDescriptor");
+ ValidateNumOutputs(workloadInfo, "MergerQueueDescriptor", 1);
if (m_Inputs.size() <= 0)
{
@@ -444,8 +434,8 @@ void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "FullyConnectedQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "FullyConnectedQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FullyConnectedQueueDescriptor", 2, "output");
if (!(workloadInfo.m_InputTensorInfos[0].GetNumDimensions() == 2 ||
@@ -487,8 +477,8 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
//---------------------------------------------------------------
void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "NormalizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "NormalizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "NormalizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "NormalizationQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"NormalizationQueueDescriptor",
@@ -498,8 +488,8 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "AdditionQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "AdditionQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -513,8 +503,8 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MultiplicationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MultiplicationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -526,8 +516,8 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "BatchNormalizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "BatchNormalizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"BatchNormalizationQueueDescriptor",
@@ -554,8 +544,8 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf
void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "Convolution2dQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "Convolution2dQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "Convolution2dQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "Convolution2dQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "output");
@@ -580,8 +570,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1);
ValidateTensorNumDimensions(
workloadInfo.m_InputTensorInfos[0], "DepthwiseConvolution2dQueueDescriptor", 4, "input");
@@ -625,8 +615,8 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "PermuteQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "PermuteQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "PermuteQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "PermuteQueueDescriptor", 1);
const PermutationVector& mapping = m_Parameters.m_DimMappings;
@@ -650,8 +640,8 @@ void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "Pooling2dQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "Pooling2dQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "Pooling2dQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "Pooling2dQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "output");
@@ -659,8 +649,8 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ResizeBilinearQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ResizeBilinearQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "output");
@@ -694,8 +684,8 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "FakeQuantizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "FakeQuantizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "output");
@@ -713,8 +703,8 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "L2NormalizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "L2NormalizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "L2NormalizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "L2NormalizationQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "output");
@@ -727,8 +717,8 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateNoInputs(workloadInfo, "ConstantQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ConstantQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ConstantQueueDescriptor", 0);
+ ValidateNumOutputs(workloadInfo, "ConstantQueueDescriptor", 1);
if (!m_LayerOutput)
{
@@ -744,8 +734,8 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ReshapeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ReshapeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ReshapeQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ReshapeQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements())
{
@@ -757,8 +747,8 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output");
@@ -804,8 +794,8 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "FloorQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "FlootQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "FloorQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "FlootQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0] != workloadInfo.m_OutputTensorInfos[0])
{
@@ -821,8 +811,8 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32)
{
@@ -843,8 +833,8 @@ void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo
void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float16)
{
@@ -864,8 +854,8 @@ void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo
void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "DivisionQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DivisionQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -877,8 +867,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "SubtractionQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "SubtractionQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -890,8 +880,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MaximumQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MaximumQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -903,8 +893,8 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "MeanQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MeanQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "MeanQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
@@ -929,8 +919,8 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "PadQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "PadQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "PadQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "PadQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
@@ -948,8 +938,8 @@ void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "QuantizeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "QuantizeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "QuantizeQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "QuantizeQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32)
@@ -966,14 +956,14 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "BatchToSpaceNdQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "BatchToSpaceNdQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1);
}
void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "StridedSliceQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "StridedSliceQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const uint32_t rank = input.GetNumDimensions();
@@ -1015,8 +1005,8 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MinimumQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MinimumQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1028,14 +1018,14 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "DebugQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DebugQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DebugQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "DebugQueueDescriptor", 1);
}
void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "EqualQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "EqualQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "EqualQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "EqualQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1052,8 +1042,8 @@ void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "GreaterQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "GreaterQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "GreaterQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "GreaterQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1070,8 +1060,8 @@ void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "RsqrtQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "RsqrtQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "RsqrtQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "RsqrtQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"RsqrtQueueDescriptor",
@@ -1081,8 +1071,8 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "GatherQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "GatherQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "GatherQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "GatherQueueDescriptor", 1);
const TensorInfo& indices = workloadInfo.m_InputTensorInfos[1];
@@ -1102,7 +1092,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2);
if (workloadInfo.m_OutputTensorInfos.size() != 4)
{
@@ -1155,8 +1145,8 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI
void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "DequantizeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DequantizeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DequantizeQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "DequantizeQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedAsymm8 &&
workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedSymm16)
@@ -1172,8 +1162,8 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MergeQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MergeQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1192,6 +1182,42 @@ void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output");
}
+void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumInputs(workloadInfo, "SwitchQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "SwitchQueueDescriptor", 2);
+
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "SwitchQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "SwitchQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "SwitchQueueDescriptor");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[0],
+ "SwitchQueueDescriptor",
+ "input0",
+ "output0");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[1],
+ "SwitchQueueDescriptor",
+ "input0",
+ "output1");
+}
+
void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
// This is internally generated so it should not need validation.