diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 8847b4efbf..1dac498c11 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -593,9 +593,10 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3; - //inputChannels * channelMultiplier should be equal to outputChannels. + // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout + // inputChannels * channelMultiplier should be equal to outputChannels. const unsigned int numWeightChannelMultiplier = m_Weight->GetTensorInfo().GetShape()[0]; - const unsigned int numWeightInputChannels = m_Weight->GetTensorInfo().GetShape()[channelIndex]; + const unsigned int numWeightInputChannels = m_Weight->GetTensorInfo().GetShape()[1]; const unsigned int numWeightOutputChannels = workloadInfo.m_OutputTensorInfos[0].GetShape()[channelIndex]; if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels) { |