aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorMatthew Jackson <matthew.jackson@arm.com>2019-07-04 14:59:16 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-07-10 12:06:51 +0000
commit2b8c1da565871b3e69567c2cfc46c8dcbef301aa (patch)
tree682327de212e273405cb257028568db997644c35 /src/backends/backendsCommon/WorkloadData.cpp
parentad5293a86e315049de36afd723dcd1a7e70681a7 (diff)
downloadarmnn-2b8c1da565871b3e69567c2cfc46c8dcbef301aa.tar.gz
IVGCVSW-3418 Add Arm NN front end support for the new Stack layer
* Added new StackLayer class * Made necessary changes to Descriptors, ILayerSupport, ILayerVisitor, etc. * Added unit tests Signed-off-by: Matthew Jackson <matthew.jackson@arm.com> Change-Id: Ieb97a928a342ffe1901c6058eb895711c358fd3d
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp85
1 files changed, 85 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 324c1debc0..878602391c 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -582,6 +582,91 @@ void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
//---------------------------------------------------------------
+void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumOutputs(workloadInfo, "StackQueueDescriptor", 1);
+
+ if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Must have the defined number of input tensors.");
+ }
+
+ // All inputs must have the same shape, which is defined in parameters
+ const TensorShape& inputShape = m_Parameters.m_InputShape;
+ for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
+ {
+ if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: All input tensor shapes "
+ "must match the defined shape.");
+ }
+ }
+
+ // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
+ // since the output tensor has an additional dimension.
+ if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Axis may not be greater "
+ "than the number of input dimensions.");
+ }
+
+ // Output shape must be as inferred from the input shape
+ const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
+ for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
+ {
+ if (outputShape[i] != inputShape[i])
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Output tensor must "
+ "match shape inferred from input tensor.");
+ }
+ }
+
+ if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Output tensor must "
+ "match shape inferred from input tensor.");
+ }
+
+ for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
+ {
+ if (outputShape[i] != inputShape[i-1])
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Output tensor must "
+ "match shape inferred from input tensor.");
+ }
+ }
+
+ // Check the supported data types
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::Boolean,
+ DataType::Signed32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "StackQueueDescriptor");
+
+ for (unsigned int i = 1; i < workloadInfo.m_InputTensorInfos.size(); ++i)
+ {
+ ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_InputTensorInfos[i],
+ "StackQueueDescriptor",
+ "InputTensor[0]",
+ "InputTensor[" + std::to_string(i) + "]");
+ }
+ ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[0],
+ "StackQueueDescriptor",
+ "InputTensor[0]",
+ "OutputTensor[0]");
+}
+
+//---------------------------------------------------------------
void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateNumInputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);