aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp38
1 files changed, 23 insertions, 15 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index be0ac707a8..44a6a17b37 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -390,13 +390,6 @@ void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
"not set on tensor {1}.", descName, tensorName));
}
-
- if (quantizationDim.value() != 0)
- {
- throw InvalidArgumentException(fmt::format(
- "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, "
- "but got: {2}", descName, tensorName, quantizationDim.value()));
- }
}
void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
@@ -1386,17 +1379,32 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
- // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout
+ // Expected weight shape: [ 1, H, W, I*M ] - This shape does NOT depend on the data layout
// inputChannels * channelMultiplier should be equal to outputChannels.
- const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0];
- const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1];
- const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex];
- if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels)
+ const unsigned int numWeightOutputChannels = weightTensorInfo.GetShape()[3]; // I*M=Cout
+ const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
+ if (numWeightOutputChannels != numOutputChannels)
+ {
+ throw InvalidArgumentException(fmt::format(
+ "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
+ "But 4th dimension is not equal to Cout. Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
+ descriptorName,
+ numOutputChannels,
+ weightTensorInfo.GetShape()[0],
+ weightTensorInfo.GetShape()[1],
+ weightTensorInfo.GetShape()[2],
+ weightTensorInfo.GetShape()[3]));
+ }
+ if (weightTensorInfo.GetShape()[0] != 1)
{
throw InvalidArgumentException(fmt::format(
- "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) "
- "multiplied by channel_multiplier (provided {3}).",
- descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier));
+ "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
+ "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
+ descriptorName,
+ weightTensorInfo.GetShape()[0],
+ weightTensorInfo.GetShape()[1],
+ weightTensorInfo.GetShape()[2],
+ weightTensorInfo.GetShape()[3]));
}
ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);