diff options
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index 9a4864542f..16d34c8c9d 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -89,10 +89,17 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, int input = tfLiteInterpreter->inputs()[inputIndex]; TfLiteIntArray* inputDims = tfLiteInterpreter->tensor(input)->dims; - long inputSize = 1; - for (unsigned int dim = 0; dim < static_cast<unsigned int>(inputDims->size); ++dim) + unsigned int inputSize = 1; + if (params.m_InputTensorShapes.size() > 0) { - inputSize *= inputDims->data[dim]; + inputSize = params.m_InputTensorShapes[inputIndex]->GetNumElements(); + } + else + { + for (unsigned int dim = 0; dim < static_cast<unsigned int>(inputDims->size); ++dim) + { + inputSize *= inputDims->data[dim]; + } } if (params.m_InputTypes[inputIndex].compare("float") == 0) @@ -108,10 +115,10 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, std::vector<float> tensorData; PopulateTensorWithDataGeneric<float>(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), - dataFile, - [](const std::string& s) - { return std::stof(s); }); + inputSize, + dataFile, + [](const std::string& s) + { return std::stof(s); }); std::copy(tensorData.begin(), tensorData.end(), inputData); } @@ -128,7 +135,7 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, std::vector<int8_t> tensorData; PopulateTensorWithDataGeneric<int8_t>(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), + inputSize, dataFile, [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); @@ -148,7 +155,7 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, std::vector<int32_t> tensorData; PopulateTensorWithDataGeneric<int32_t>(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), + inputSize, dataFile, [](const std::string& s) { return std::stoi(s); }); @@ -169,7 +176,7 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, std::vector<uint8_t> tensorData; PopulateTensorWithDataGeneric<uint8_t>(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), + inputSize, dataFile, [](const std::string& s) { return armnn::numeric_cast<uint8_t>(std::stoi(s)); }); @@ -189,7 +196,7 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, std::vector<int8_t> tensorData; PopulateTensorWithDataGeneric<int8_t>(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), + inputSize, dataFile, [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); |