diff options
Diffstat (limited to 'src/armnnUtils/TensorUtils.cpp')
-rw-r--r-- | src/armnnUtils/TensorUtils.cpp | 91 |
1 files changed, 88 insertions, 3 deletions
diff --git a/src/armnnUtils/TensorUtils.cpp b/src/armnnUtils/TensorUtils.cpp index d77f5d74c3..9e3d719211 100644 --- a/src/armnnUtils/TensorUtils.cpp +++ b/src/armnnUtils/TensorUtils.cpp @@ -128,12 +128,11 @@ TensorShape ExpandDims(const TensorShape& tensorShape, int axis) } outputShape.insert(outputShape.begin() + axis, 1); - return TensorShape(outputDim, outputShape.data()); + return { outputDim, outputShape.data() }; } std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape) { - unsigned int outputDimSize = 0; std::vector<unsigned int> squeezedDims; for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i) @@ -141,7 +140,6 @@ std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape) if (tensorShape[i] != 1) { squeezedDims.push_back(tensorShape[i]); - ++outputDimSize; } } return squeezedDims; @@ -201,4 +199,91 @@ std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::Tensor return { axisFactor, scales }; } +template<typename PrimitiveType> +void CheckSizes(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo, unsigned int size = 1) +{ + if (data.size() / size != tensorInfo.GetNumElements()) + { + throw InvalidArgumentException( + fmt::format("The data does not contain the expected number of elements {} != {}. {}", + data.size(), tensorInfo.GetNumElements(), CHECK_LOCATION().AsString())); + } +} + +template<typename PrimitiveType> +std::unique_ptr<float[]> ToFloatArray(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo) +{ + CheckSizes(data, tensorInfo); + + std::unique_ptr<float[]> returnBuffer(new float[tensorInfo.GetNumElements()]); + + if (tensorInfo.HasPerAxisQuantization()) + { + unsigned int axis = tensorInfo.GetQuantizationDim().value(); + auto axisDimensionality = tensorInfo.GetShape()[axis]; + auto axisFactor = armnnUtils::GetNumElementsAfter(tensorInfo.GetShape(), axis); + + for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i) + { + unsigned int axisIndex; + + if (i < axisFactor) + { + axisIndex = 0; + } + else + { + axisIndex = (i / axisFactor) % axisDimensionality; + } + returnBuffer[i] = Dequantize<PrimitiveType>(data[i], + tensorInfo.GetQuantizationScales()[axisIndex], + tensorInfo.GetQuantizationOffset()); + } + } + else + { + for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i) + { + returnBuffer[i] = Dequantize<PrimitiveType>(data[i], + tensorInfo.GetQuantizationScale(), + tensorInfo.GetQuantizationOffset()); + } + } + return returnBuffer; +} + +std::unique_ptr<float[]> ToFloatArray(const std::vector<uint8_t>& data, const armnn::TensorInfo& tensorInfo) +{ + if (tensorInfo.GetDataType() == DataType::QAsymmS8 || tensorInfo.GetDataType() == DataType::QSymmS8) + { + CheckSizes(data, tensorInfo); + std::vector<int8_t> buffer(tensorInfo.GetNumElements()); + ::memcpy(buffer.data(), data.data(), data.size()); + return ToFloatArray<int8_t>(buffer, tensorInfo); + } + else if (tensorInfo.GetDataType() == DataType::QAsymmU8) + { + CheckSizes(data, tensorInfo); + return ToFloatArray<uint8_t>(data, tensorInfo); + } + else if (tensorInfo.GetDataType() == DataType::Signed32) + { + CheckSizes(data, tensorInfo, 4); + std::vector<int32_t> buffer(tensorInfo.GetNumElements()); + ::memcpy(buffer.data(), data.data(), data.size()); + return ToFloatArray<int32_t>(buffer, tensorInfo); + } + else if (tensorInfo.GetDataType() == DataType::Signed64) + { + CheckSizes(data, tensorInfo, 8); + std::vector<int64_t> buffer(tensorInfo.GetNumElements()); + ::memcpy(buffer.data(), data.data(), data.size()); + return ToFloatArray<int64_t>(buffer, tensorInfo); + } + throw InvalidArgumentException( + fmt::format("Unsupported datatype {}. {}", + GetDataTypeName(tensorInfo.GetDataType()), + CHECK_LOCATION().AsString())); +} + } // namespace armnnUtils |