diff options
Diffstat (limited to 'tests/ImageTensorGenerator/ImageTensorGenerator.hpp')
-rw-r--r-- | tests/ImageTensorGenerator/ImageTensorGenerator.hpp | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/ImageTensorGenerator/ImageTensorGenerator.hpp b/tests/ImageTensorGenerator/ImageTensorGenerator.hpp index 5aa2ca8124..6d2e549360 100644 --- a/tests/ImageTensorGenerator/ImageTensorGenerator.hpp +++ b/tests/ImageTensorGenerator/ImageTensorGenerator.hpp @@ -56,6 +56,10 @@ NormalizationParameters GetNormalizationParameters(const SupportedFrontend& mode normParams.mean = { 128.0, 128.0, 128.0 }; break; case armnn::DataType::QAsymmU8: + break; + case armnn::DataType::QAsymmS8: + normParams.mean = { 128.0, 128.0, 128.0 }; + break; default: break; } @@ -138,7 +142,7 @@ std::vector<int> PrepareImageTensor<int>(const std::string& imagePath, return imageDataInt; } -// Prepare qasymm8 image tensor +// Prepare qasymmu8 image tensor template <> std::vector<uint8_t> PrepareImageTensor<uint8_t>(const std::string& imagePath, unsigned int newWidth, @@ -158,6 +162,26 @@ std::vector<uint8_t> PrepareImageTensor<uint8_t>(const std::string& imagePath, return imageDataQasymm8; } +// Prepare qasymms8 image tensor +template <> +std::vector<int8_t> PrepareImageTensor<int8_t>(const std::string& imagePath, + unsigned int newWidth, + unsigned int newHeight, + const NormalizationParameters& normParams, + unsigned int batchSize, + const armnn::DataLayout& outputLayout) +{ + // Get float32 image tensor + std::vector<float> imageDataFloat = + PrepareImageTensor<float>(imagePath, newWidth, newHeight, normParams, batchSize, outputLayout); + std::vector<int8_t> imageDataQasymms8; + imageDataQasymms8.reserve(imageDataFloat.size()); + // Convert to uint8 image tensor with static cast + std::transform(imageDataFloat.begin(), imageDataFloat.end(), std::back_inserter(imageDataQasymms8), + [](float val) { return static_cast<uint8_t>(val); }); + return imageDataQasymms8; +} + /** Write image tensor to ofstream * * @param[in] imageData Image tensor data @@ -176,3 +200,11 @@ void WriteImageTensorImpl<uint8_t>(const std::vector<uint8_t>& imageData, std::o { std::copy(imageData.begin(), imageData.end(), std::ostream_iterator<int>(imageTensorFile, " ")); } + +// For int8_t image tensor, cast it to int before writing it to prevent writing data as characters instead of +// numerical values +template <> +void WriteImageTensorImpl<int8_t>(const std::vector<int8_t>& imageData, std::ofstream& imageTensorFile) +{ + std::copy(imageData.begin(), imageData.end(), std::ostream_iterator<int>(imageTensorFile, " ")); +}
\ No newline at end of file |