diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index d89b5899ba..7a46741964 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1382,7 +1382,13 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa { const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"}; - ValidateNumInputs(workloadInfo, descriptorName, 1); + uint32_t numInputs = 2; + if (m_Parameters.m_BiasEnabled) + { + numInputs = 3; + } + + ValidateNumInputs(workloadInfo, descriptorName, numInputs); ValidateNumOutputs(workloadInfo, descriptorName, 1); const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; @@ -1391,9 +1397,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input"); ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output"); - ValidatePointer(m_Weight, descriptorName, "weight"); - - const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo(); + const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1]; ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight"); if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 ) @@ -1447,9 +1451,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa Optional<TensorInfo> optionalBiasTensorInfo; if (m_Parameters.m_BiasEnabled) { - ValidatePointer(m_Bias, descriptorName, "bias"); - - optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo()); + optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]); const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value(); ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName); |