diff options
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index a5114ecfca..11d3542405 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -10,6 +10,7 @@ #include <armnn/TypesUtils.hpp> #include <armnn/LstmParams.hpp> #include <armnn/QuantizedLstmParams.hpp> +#include <armnn/Logging.hpp> #include <armnnUtils/Permute.hpp> #include <armnnUtils/Transpose.hpp> @@ -816,8 +817,16 @@ INetworkPtr IDeserializer::DeserializerImpl::CreateNetworkFromBinary(const std:: armnn::INetworkPtr IDeserializer::DeserializerImpl::CreateNetworkFromBinary(std::istream& binaryContent) { ResetParser(); - std::vector<uint8_t> content((std::istreambuf_iterator<char>(binaryContent)), std::istreambuf_iterator<char>()); - GraphPtr graph = LoadGraphFromBinary(content.data(), content.size()); + if (binaryContent.fail()) { + ARMNN_LOG(error) << (std::string("Cannot read input")); + throw ParseException("Unable to read Input stream data"); + } + binaryContent.seekg(0, std::ios::end); + const std::streamoff size = binaryContent.tellg(); + std::vector<char> content(static_cast<size_t>(size)); + binaryContent.seekg(0); + binaryContent.read(content.data(), static_cast<std::streamsize>(size)); + GraphPtr graph = LoadGraphFromBinary(reinterpret_cast<uint8_t*>(content.data()), static_cast<size_t>(size)); return CreateNetworkFromGraph(graph); } |