author    Keith Davis <keith.davis@arm.com>    2020-02-11 16:51:50 +0000
committer James Conroy <james.conroy@arm.com>  2020-02-17 21:53:29 +0000
commit    0c2eeac6347533a1d3d456aebea492f5123388f3 (patch)
tree      f218fc236137791c491b680dfd24fb9706c171a6
parent    4c3c1f486ab775eacb1f6455f8468f9be2c3e4f7 (diff)
download  armnn-0c2eeac6347533a1d3d456aebea492f5123388f3.tar.gz
IVGCVSW-4436 Add ExecuteNetwork test for mobilenet_v2_int8
* Add QAsymmS8 to QueueDescriptor supportedTypes
* Add QSymmS8/QAsymmS8 to RefLayerSupport supportedTypes
* Some additional comments and refactoring

Change-Id: I8567314452e6e8f6f69cb6e458ee147d3fc92fab
Signed-off-by: Keith Davis <keith.davis@arm.com>
-rw-r--r--  include/armnn/TypesUtils.hpp                           |  9
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp                 | 23
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp           | 28
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp             | 43
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp          | 25
-rw-r--r--  src/backends/reference/workloads/RefDebugWorkload.hpp  | 12
6 files changed, 85 insertions, 55 deletions
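
The added test drives the int8 variant of MobileNet v2 through ExecuteNetwork on the reference backend. A hypothetical invocation is sketched below; the exact ExecuteNetwork flag set varies between Arm NN releases, and the model and data paths here are placeholders:

ExecuteNetwork -f tflite-binary \
               -m mobilenet_v2_int8.tflite \
               -i input -o output \
               -d input_data.txt \
               -c CpuRef
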
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 59beb33144..bf54c15ef8 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -169,6 +169,7 @@ constexpr const char* GetDataTypeName(DataType dataType)
case DataType::Float16: return "Float16";
case DataType::Float32: return "Float32";
case DataType::QAsymmU8: return "QAsymmU8";
+ case DataType::QAsymmS8: return "QAsymmS8";
case DataType::QSymmS8: return "QSymmS8";
ARMNN_NO_DEPRECATE_WARN_BEGIN
case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
@@ -233,17 +234,21 @@ constexpr bool IsQuantizedType()
return std::is_integral<T>::value;
}
-constexpr bool IsQuantizedType(DataType dataType)
+constexpr bool IsQuantized8BitType(DataType dataType)
{
ARMNN_NO_DEPRECATE_WARN_BEGIN
return dataType == DataType::QAsymmU8 ||
dataType == DataType::QAsymmS8 ||
dataType == DataType::QSymmS8 ||
- dataType == DataType::QSymmS16 ||
dataType == DataType::QuantizedSymm8PerAxis;
ARMNN_NO_DEPRECATE_WARN_END
}
+constexpr bool IsQuantizedType(DataType dataType)
+{
+ return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
+}
+
inline std::ostream& operator<<(std::ostream& os, Status stat)
{
os << GetStatusAsCString(stat);
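
The TypesUtils change splits the old predicate in two: IsQuantized8BitType() covers exactly the four 8-bit quantized types, and IsQuantizedType(DataType) is rebuilt on top of it with QSymmS16 added back. A minimal compile-time sketch of the resulting truth table, assuming the patched TypesUtils.hpp is on the include path (illustrative, not part of the patch):

#include <armnn/TypesUtils.hpp>

// Both predicates are constexpr, so the split can be checked statically.
static_assert(armnn::IsQuantized8BitType(armnn::DataType::QAsymmS8),
              "QAsymmS8 is an 8-bit quantized type");
static_assert(!armnn::IsQuantized8BitType(armnn::DataType::QSymmS16),
              "QSymmS16 is 16-bit, so it is excluded here");
static_assert(armnn::IsQuantizedType(armnn::DataType::QSymmS16),
              "...but it is still a quantized type");
static_assert(!armnn::IsQuantizedType(armnn::DataType::Float32),
              "float types are not quantized");
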
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 560cdf1779..593f3eb02d 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -301,7 +301,8 @@ void CalcPadding(uint32_t inputSize,
}
}
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes,
+ const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3})
{
armnn::DataType type;
CHECK_TENSOR_PTR(tensorPtr);
@@ -317,10 +318,12 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
case tflite::TensorType_INT8:
if (tensorPtr->quantization->zero_point.size() == 1 && tensorPtr->quantization->zero_point[0] != 0)
{
+ // Per-tensor
type = armnn::DataType::QAsymmS8;
}
else
{
+ // Per-channel
type = armnn::DataType::QSymmS8;
}
break;
@@ -388,12 +391,13 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
tensorPtr->quantization->scale.end(),
std::back_inserter(quantizationScales));
- // QSymm Per-axis
+ // QSymmS8 Per-axis
armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
safeShape.data(),
type,
quantizationScales,
- boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
+ dimensionMappings[boost::numeric_cast<unsigned int>(
+ tensorPtr->quantization->quantized_dimension)]);
return result;
}
@@ -409,10 +413,11 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
}
}
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr,
+ const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3})
{
auto const & dimensions = AsUnsignedVector(tensorPtr->shape);
- return ToTensorInfo(tensorPtr, dimensions);
+ return ToTensorInfo(tensorPtr, dimensions, dimensionMappings);
}
template<typename T>
@@ -905,8 +910,11 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd
desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
+ // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
+ PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
+
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1], permutationVector);
// Assuming input is NHWC
unsigned int inputHeight = inputTensorInfo.GetShape()[1];
@@ -922,9 +930,6 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd
inputTensorInfo.GetShape()[3],
filterTensorInfo.GetShape()[3] / inputTensorInfo.GetShape()[3] });
- // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
- PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
-
CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
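
Moving the permutation above the ToTensorInfo call matters for per-axis quantization: the new dimensionMappings parameter remaps quantized_dimension through the same [H, W, I, M] -> [M, I, H, W] reorder applied to the depthwise filter, so the quantization axis still names the right dimension afterwards. A small illustrative sketch of that remapping:

#include <armnn/Types.hpp>
#include <iostream>

int main()
{
    // Same mapping the parser uses: dim 0 (H) -> 2, 1 (W) -> 3, 2 (I) -> 1, 3 (M) -> 0.
    armnn::PermutationVector perm{ 2, 3, 1, 0 };

    // TfLite depthwise weights are typically quantized along dimension 3 of [H, W, I, M].
    unsigned int tfliteQuantDim = 3;

    // After the reorder the quantization axis becomes dimension 0 of [M, I, H, W].
    std::cout << "quantized_dimension maps to " << perm[tfliteQuantDim] << std::endl;
    return 0;
}
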
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index ebaf961fe8..fea72256a1 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -30,6 +30,8 @@ DataType GetBiasDataType(DataType inputDataType)
return DataType::Float16;
case DataType::Float32:
return DataType::Float32;
+ case DataType::QAsymmS8:
+ return DataType::Signed32;
case DataType::QAsymmU8:
return DataType::Signed32;
case DataType::QSymmS8:
@@ -357,12 +359,13 @@ void ValidateWeightDataType(const TensorInfo& inputInfo,
const std::string& descName)
{
const DataType inputType = inputInfo.GetDataType();
- if (inputType == DataType::QAsymmU8)
+ if (IsQuantized8BitType(inputType))
{
ARMNN_NO_DEPRECATE_WARN_BEGIN
const std::vector<DataType> validTypes =
{
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS8,
DataType::QuantizedSymm8PerAxis // deprecated
};
@@ -420,8 +423,7 @@ void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
const DataType inputDataType = inputInfo.GetDataType();
const DataType outputDataType = outputInfo.GetDataType();
- const bool canHavePerAxisQuantization = (inputDataType == DataType::QSymmS8 ||
- inputDataType == DataType::QAsymmU8) && inputDataType == outputDataType;
+ const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;
if (!canHavePerAxisQuantization)
{
@@ -599,6 +601,7 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -684,6 +687,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float16,
DataType::Float32,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1038,10 +1042,11 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
std::vector<DataType> supportedTypes =
{
DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::QSymmS8,
- DataType::Float16
+ DataType::QSymmS8
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -1181,6 +1186,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
{
DataType::Float32,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS16,
DataType::QSymmS8,
DataType::Float16
@@ -1255,6 +1261,7 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
{
DataType::Float32,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS16,
DataType::Float16
};
@@ -1309,6 +1316,7 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
DataType::Float32,
DataType::Float16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1560,9 +1568,10 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::Float32,
DataType::Float16,
DataType::Signed32,
+ DataType::QSymmS16,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
- DataType::QSymmS8,
- DataType::QSymmS16
+ DataType::QSymmS8
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
@@ -2208,10 +2217,7 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
- if (outputTensorInfo.GetDataType() != DataType::QAsymmS8 &&
- outputTensorInfo.GetDataType() != DataType::QAsymmU8 &&
- outputTensorInfo.GetDataType() != DataType::QSymmS8 &&
- outputTensorInfo.GetDataType() != DataType::QSymmS16)
+ if (!IsQuantizedType(outputTensorInfo.GetDataType()))
{
throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
}
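
The WorkloadData changes are mechanical: QAsymmS8 inputs get Signed32 biases like the other quantized types, and the four explicit comparisons in QuantizeQueueDescriptor::Validate collapse into a single IsQuantizedType() call. A hypothetical free-function restatement of that final check:

#include <armnn/Exceptions.hpp>
#include <armnn/Types.hpp>
#include <armnn/TypesUtils.hpp>
#include <string>

// Hypothetical standalone version of the check inside QuantizeQueueDescriptor::Validate.
void CheckQuantizeOutput(armnn::DataType outputType, const std::string& descriptorName)
{
    if (!armnn::IsQuantizedType(outputType))
    {
        throw armnn::InvalidArgumentException(
            descriptorName + ": Output of quantized layer must be quantized type.");
    }
}
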
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index c60348e529..bba83e23d4 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -4,15 +4,11 @@
//
#include "RefLayerSupport.hpp"
-#include "RefBackendId.hpp"
+#include <armnn/TypesUtils.hpp>
#include <armnn/Types.hpp>
#include <armnn/Descriptors.hpp>
-#include <armnn/BackendRegistry.hpp>
-#include <armnnUtils/DataLayoutIndexed.hpp>
-
-#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>
#include <backendsCommon/LayerSupportRules.hpp>
@@ -21,7 +17,6 @@
#include <boost/core/ignore_unused.hpp>
#include <vector>
-#include <algorithm>
#include <array>
using namespace boost;
@@ -84,9 +79,11 @@ bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
bool supported = true;
// Define supported types.
- std::array<DataType,4> supportedTypes = {
+ std::array<DataType,6> supportedTypes = {
DataType::Float32,
DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -147,10 +144,11 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,5> supportedTypes = {
+ std::array<DataType,6> supportedTypes = {
DataType::Float32,
DataType::Float16,
DataType::QSymmS8,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -420,11 +418,12 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
bool supported = true;
// Define supported types.
- std::array<DataType,5> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS8,
DataType::QSymmS16
};
@@ -439,13 +438,14 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
"Reference Convolution2d: input and output types mismatched.");
const DataType inputType = input.GetDataType();
- if (inputType == DataType::QAsymmU8)
+ if (IsQuantized8BitType(inputType))
{
ARMNN_NO_DEPRECATE_WARN_BEGIN
- std::array<DataType, 3> supportedWeightTypes =
+ std::array<DataType, 4> supportedWeightTypes =
{
DataType::QAsymmU8,
DataType::QSymmS8,
+ DataType::QAsymmS8,
DataType::QuantizedSymm8PerAxis // deprecated
};
ARMNN_NO_DEPRECATE_WARN_END
@@ -485,11 +485,12 @@ bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
{
bool supported = true;
- std::array<DataType, 6> supportedTypes =
+ std::array<DataType, 7> supportedTypes =
{
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
+ DataType::QAsymmS8,
DataType::QSymmS8,
DataType::QSymmS16,
DataType::Signed32
@@ -545,10 +546,12 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
bool supported = true;
// Define supported types.
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
DataType::Float32,
DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -572,7 +575,7 @@ bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
ARMNN_NO_DEPRECATE_WARN_END
const DataType inputType = input.GetDataType();
- if (inputType == DataType::QAsymmU8)
+ if (IsQuantized8BitType(inputType))
{
supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
@@ -1413,10 +1416,12 @@ bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
bool supported = true;
// Define supported output and inputs types.
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
DataType::Float32,
DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
@@ -1476,15 +1481,17 @@ bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
ignore_unused(output);
ignore_unused(descriptor);
// Define supported output types.
- std::array<DataType,6> supportedOutputTypes =
+ std::array<DataType,7> supportedOutputTypes =
{
DataType::Float32,
DataType::Float16,
DataType::Signed32,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS8,
DataType::QSymmS16
};
+
return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
"Reference reshape: input type not supported.");
}
@@ -1586,10 +1593,12 @@ bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
{
boost::ignore_unused(descriptor);
bool supported = true;
- std::array<DataType,4> supportedTypes =
+ std::array<DataType,6> supportedTypes =
{
DataType::Float32,
DataType::Float16,
+ DataType::QSymmS8,
+ DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16
};
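
Every IsXxxSupported override above follows the same pattern: grow a fixed-size std::array of DataTypes by the new entries and let the TypeAnyOf rule test membership. A reduced sketch of that pattern, using std::find in place of the internal CheckSupportRule machinery:

#include <armnn/Types.hpp>
#include <algorithm>
#include <array>

// Simplified stand-in for CheckSupportRule(TypeAnyOf(info, supportedTypes), ...).
bool TypeIsSupported(armnn::DataType type)
{
    const std::array<armnn::DataType, 6> supportedTypes =
    {
        armnn::DataType::Float32,
        armnn::DataType::Float16,
        armnn::DataType::QSymmS8,
        armnn::DataType::QAsymmS8,
        armnn::DataType::QAsymmU8,
        armnn::DataType::QSymmS16
    };
    return std::find(supportedTypes.begin(), supportedTypes.end(), type)
           != supportedTypes.end();
}
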
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 792bd7d3ad..dadb456104 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -45,17 +45,22 @@ bool IsDataType(const WorkloadInfo& info)
return false;
}
+bool IsSigned32(const WorkloadInfo& info)
+{
+ return IsDataType<DataType::Signed32>(info);
+}
+
bool IsFloat16(const WorkloadInfo& info)
{
return IsDataType<DataType::Float16>(info);
}
-bool IsQSymm16(const WorkloadInfo& info)
+bool IsQSymmS16(const WorkloadInfo& info)
{
return IsDataType<DataType::QSymmS16>(info);
}
-bool IsQSymm8(const WorkloadInfo& info)
+bool IsQSymmS8(const WorkloadInfo& info)
{
return IsDataType<DataType::QSymmS8>(info);
}
@@ -187,20 +192,20 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescr
{
return std::make_unique<RefDebugFloat16Workload>(descriptor, info);
}
- if (IsQSymm16(info))
+ if (IsQSymmS16(info))
{
- return std::make_unique<RefDebugQSymm16Workload>(descriptor, info);
+ return std::make_unique<RefDebugQSymmS16Workload>(descriptor, info);
}
- if (IsQSymm8(info))
+ if (IsQSymmS8(info))
{
- return std::make_unique<RefDebugQSymm8Workload>(descriptor, info);
+ return std::make_unique<RefDebugQSymmS8Workload>(descriptor, info);
}
- if (IsDataType<DataType::Signed32>(info))
+ if (IsSigned32(info))
{
return std::make_unique<RefDebugSigned32Workload>(descriptor, info);
}
- return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymm8Workload>(descriptor, info);
+ return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymmU8Workload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
@@ -410,7 +415,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDes
std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- if (IsQSymm16(info))
+ if (IsQSymmS16(info))
{
return std::make_unique<RefPadQSymm16Workload>(descriptor, info);
}
@@ -424,7 +429,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto
std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- if (IsQSymm16(info))
+ if (IsQSymmS16(info))
{
return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info);
}
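
The renamed IsQSymmS16/IsQSymmS8 helpers and the new IsSigned32 are all thin wrappers over the IsDataType<> template at the top of this file, which scans the workload's input and output tensors for a given DataType. An illustrative restatement of that pattern (the names and include path here are assumptions, not the factory's own code):

#include <armnn/Types.hpp>
#include <backendsCommon/WorkloadInfo.hpp> // assumed include path for WorkloadInfo
#include <algorithm>

// True if any input or output tensor of the workload has the given type.
template <armnn::DataType ArmnnType>
bool AnyTensorHasType(const armnn::WorkloadInfo& info)
{
    auto match = [](const armnn::TensorInfo& t) { return t.GetDataType() == ArmnnType; };
    return std::any_of(info.m_InputTensorInfos.begin(),  info.m_InputTensorInfos.end(),  match)
        || std::any_of(info.m_OutputTensorInfos.begin(), info.m_OutputTensorInfos.end(), match);
}
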
diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp
index a15a863892..4966ca3432 100644
--- a/src/backends/reference/workloads/RefDebugWorkload.hpp
+++ b/src/backends/reference/workloads/RefDebugWorkload.hpp
@@ -37,11 +37,11 @@ private:
DebugCallbackFunction m_Callback;
};
-using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>;
-using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>;
-using RefDebugQAsymm8Workload = RefDebugWorkload<DataType::QAsymmU8>;
-using RefDebugQSymm16Workload = RefDebugWorkload<DataType::QSymmS16>;
-using RefDebugQSymm8Workload = RefDebugWorkload<DataType::QSymmS8>;
-using RefDebugSigned32Workload = RefDebugWorkload<DataType::Signed32>;
+using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>;
+using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>;
+using RefDebugQAsymmU8Workload = RefDebugWorkload<DataType::QAsymmU8>;
+using RefDebugQSymmS16Workload = RefDebugWorkload<DataType::QSymmS16>;
+using RefDebugQSymmS8Workload = RefDebugWorkload<DataType::QSymmS8>;
+using RefDebugSigned32Workload = RefDebugWorkload<DataType::Signed32>;
} // namespace armnn
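
With the alias names aligned to the current DataType spellings, instantiating a debug workload reads the same for every specialisation; a hypothetical helper mirroring the factory code above:

#include <memory>
#include "RefDebugWorkload.hpp"

// Sketch only: descriptor and info would come from the factory's caller.
std::unique_ptr<armnn::IWorkload> MakeQAsymmU8Debug(const armnn::DebugQueueDescriptor& descriptor,
                                                    const armnn::WorkloadInfo& info)
{
    return std::make_unique<armnn::RefDebugQAsymmU8Workload>(descriptor, info);
}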