diff options
author | Jan Eilers <jan.eilers@arm.com> | 2021-06-02 12:01:25 +0100 |
---|---|---|
committer | Jan Eilers <jan.eilers@arm.com> | 2021-06-16 11:31:42 +0000 |
commit | 53ef79504b4c881c572735393c2eede5fa556c46 (patch) | |
tree | f6e0cd27c4d03075fa154074c5b12d7c8c3149f7 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 77fe76bfa8cb798943821d1f3e432c228e1cdee3 (diff) | |
download | armnn-53ef79504b4c881c572735393c2eede5fa556c46.tar.gz |
IVGCVSW-5826 Change weights layout for depthwise to [1,H,W,I*M]
* This change is necessary because tflite uses a [1,H,W,I*M] format
and uses the I*M dimension for per axis quantization. Our previous
layout [M,I,H,W] can't handle the correlating quantization scales.
* Updates Onnx-, TfLiteParser and TfliteDelegate
* Updates the CpuRef, CpuAcc and GpuAcc backends
* Adjusts unit tests
* Adds test to ensure models with old layout can still be read and
executed
* Adds conversion function to previous layout [1,H,W,I*M] --> [M,I,H,W]
which can be used by backend developers
!android-nn-driver:5553
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: Ifef23368b8c3702cf315a5838d214f7dc13c0152
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 38 |
1 files changed, 23 insertions, 15 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index be0ac707a8..44a6a17b37 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -390,13 +390,6 @@ void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo, throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization " "not set on tensor {1}.", descName, tensorName)); } - - if (quantizationDim.value() != 0) - { - throw InvalidArgumentException(fmt::format( - "{0}: Quantization dimension for per-axis quantization expected to be 0 on tensor {1}, " - "but got: {2}", descName, tensorName, quantizationDim.value())); - } } void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo, @@ -1386,17 +1379,32 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3; - // Expected weight shape: [ M, I, H, W ] - This shape does NOT depend on the data layout + // Expected weight shape: [ 1, H, W, I*M ] - This shape does NOT depend on the data layout // inputChannels * channelMultiplier should be equal to outputChannels. - const unsigned int numWeightChannelMultiplier = weightTensorInfo.GetShape()[0]; - const unsigned int numWeightInputChannels = weightTensorInfo.GetShape()[1]; - const unsigned int numWeightOutputChannels = outputTensorInfo.GetShape()[channelIndex]; - if (numWeightChannelMultiplier * numWeightInputChannels != numWeightOutputChannels) + const unsigned int numWeightOutputChannels = weightTensorInfo.GetShape()[3]; // I*M=Cout + const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex]; + if (numWeightOutputChannels != numOutputChannels) + { + throw InvalidArgumentException(fmt::format( + "{0}: The weight format in armnn is expected to be [1, H, W, Cout]." + "But 4th dimension is not equal to Cout. Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]", + descriptorName, + numOutputChannels, + weightTensorInfo.GetShape()[0], + weightTensorInfo.GetShape()[1], + weightTensorInfo.GetShape()[2], + weightTensorInfo.GetShape()[3])); + } + if (weightTensorInfo.GetShape()[0] != 1) { throw InvalidArgumentException(fmt::format( - "{0}: output_channels (provided {1}) should be equal to input_channels (provided {2}) " - "multiplied by channel_multiplier (provided {3}).", - descriptorName, numWeightOutputChannels, numWeightInputChannels, numWeightChannelMultiplier)); + "{0}: The weight format in armnn is expected to be [1, H, W, Cout]." + "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]", + descriptorName, + weightTensorInfo.GetShape()[0], + weightTensorInfo.GetShape()[1], + weightTensorInfo.GetShape()[2], + weightTensorInfo.GetShape()[3])); } ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName); |