diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-06-20 14:28:19 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-21 09:54:30 +0000 |
commit | 639fb0437d1a5a8a6ea737fed5a16b554dfffead (patch) | |
tree | 5b89adc18c1a071d23747a28dcddcfca41e4d815 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 713e95c8c531c5cecd804a7cecc8af745917315c (diff) | |
download | armnn-639fb0437d1a5a8a6ea737fed5a16b554dfffead.tar.gz |
IVGCVSW-3319 Add frontend support for TransposeConvolution2d Layer
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ic06f63f1eff255e697facf319e2ac4c83d782e7c
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index adba86c79a..5ca492888f 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1800,4 +1800,45 @@ void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const "alpha"); } +void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 1); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], descriptorName, 4, "input"); + ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], descriptorName, 4, "output"); + + ValidatePointer(m_Weight, descriptorName, "weight"); + ValidateTensorNumDimensions(m_Weight->GetTensorInfo(), descriptorName, 4, "weight"); + + ValidateTensorDataType(m_Weight->GetTensorInfo(), + workloadInfo.m_InputTensorInfos[0].GetDataType(), + descriptorName, + "weight"); + + if (m_Parameters.m_BiasEnabled) + { + ValidateTensorNumDimensions(m_Bias->GetTensorInfo(), descriptorName, 1, "bias"); + + ValidateTensorDataType(m_Bias->GetTensorInfo(), + GetBiasDataType(workloadInfo.m_InputTensorInfos[0].GetDataType()), + descriptorName, "bias"); + + ValidateBiasTensorQuantization(m_Bias->GetTensorInfo(), + workloadInfo.m_InputTensorInfos[0], + m_Weight->GetTensorInfo(), + descriptorName); + } + + ValidateTensorQuantizationMultiplier(workloadInfo.m_InputTensorInfos[0], + m_Weight->GetTensorInfo(), + workloadInfo.m_OutputTensorInfos[0], + descriptorName, + "input", + "weights", + "output"); +} + } //namespace armnn |