From c577f2c6a3b4ddb6ba87a882723c53a248afbeba Mon Sep 17 00:00:00 2001 From: telsoa01 Date: Fri, 31 Aug 2018 09:22:23 +0100 Subject: Release 18.08 --- tests/CMakeLists.txt | 97 +- tests/CaffeAlexNet-Armnn/CaffeAlexNet-Armnn.cpp | 13 +- .../CaffeCifar10AcrossChannels-Armnn.cpp | 11 +- .../CaffeInception_BN-Armnn.cpp | 13 +- tests/CaffeMnist-Armnn/CaffeMnist-Armnn.cpp | 11 +- tests/CaffePreprocessor.cpp | 47 + tests/CaffePreprocessor.hpp | 40 + tests/CaffeResNet-Armnn/CaffeResNet-Armnn.cpp | 14 +- .../CaffeSqueezeNet1_0-Armnn.cpp | 6 +- tests/CaffeVGG-Armnn/CaffeVGG-Armnn.cpp | 14 +- tests/CaffeYolo-Armnn/CaffeYolo-Armnn.cpp | 1 + tests/Cifar10Database.hpp | 3 +- tests/ExecuteNetwork/ExecuteNetwork.cpp | 518 ++++++++-- tests/ImageNetDatabase.cpp | 47 - tests/ImageNetDatabase.hpp | 37 - tests/ImagePreprocessor.cpp | 74 ++ tests/ImagePreprocessor.hpp | 73 ++ tests/InferenceModel.hpp | 270 ++++-- tests/InferenceTest.cpp | 23 +- tests/InferenceTest.hpp | 44 +- tests/InferenceTest.inl | 54 +- tests/InferenceTestImage.cpp | 158 ++- tests/InferenceTestImage.hpp | 25 +- tests/MnistDatabase.cpp | 8 +- tests/MnistDatabase.hpp | 3 +- tests/MobileNetDatabase.cpp | 133 --- tests/MobileNetDatabase.hpp | 36 - .../MultipleNetworksCifar10.cpp | 30 +- tests/OnnxMnist-Armnn/OnnxMnist-Armnn.cpp | 39 + tests/OnnxMnist-Armnn/Validation.txt | 1000 +++++++++++++++++++ tests/OnnxMobileNet-Armnn/OnnxMobileNet-Armnn.cpp | 60 ++ tests/OnnxMobileNet-Armnn/Validation.txt | 201 ++++ tests/OnnxMobileNet-Armnn/labels.txt | 1001 ++++++++++++++++++++ tests/TfCifar10-Armnn/TfCifar10-Armnn.cpp | 12 +- tests/TfInceptionV3-Armnn/TfInceptionV3-Armnn.cpp | 13 +- .../TfLiteMobilenetQuantized-Armnn.cpp | 84 ++ .../TfLiteMobilenetQuantized-Armnn/Validation.txt | 201 ++++ tests/TfLiteMobilenetQuantized-Armnn/labels.txt | 1001 ++++++++++++++++++++ tests/TfMnist-Armnn/TfMnist-Armnn.cpp | 11 +- tests/TfMobileNet-Armnn/TfMobileNet-Armnn.cpp | 26 +- .../TfResNext_Quantized-Armnn.cpp | 13 +- tests/YoloDatabase.cpp | 8 +- tests/YoloInferenceTest.hpp | 12 +- 43 files changed, 4946 insertions(+), 539 deletions(-) create mode 100644 tests/CaffePreprocessor.cpp create mode 100644 tests/CaffePreprocessor.hpp delete mode 100644 tests/ImageNetDatabase.cpp delete mode 100644 tests/ImageNetDatabase.hpp create mode 100644 tests/ImagePreprocessor.cpp create mode 100644 tests/ImagePreprocessor.hpp delete mode 100644 tests/MobileNetDatabase.cpp delete mode 100644 tests/MobileNetDatabase.hpp create mode 100644 tests/OnnxMnist-Armnn/OnnxMnist-Armnn.cpp create mode 100644 tests/OnnxMnist-Armnn/Validation.txt create mode 100644 tests/OnnxMobileNet-Armnn/OnnxMobileNet-Armnn.cpp create mode 100644 tests/OnnxMobileNet-Armnn/Validation.txt create mode 100644 tests/OnnxMobileNet-Armnn/labels.txt create mode 100644 tests/TfLiteMobilenetQuantized-Armnn/TfLiteMobilenetQuantized-Armnn.cpp create mode 100644 tests/TfLiteMobilenetQuantized-Armnn/Validation.txt create mode 100644 tests/TfLiteMobilenetQuantized-Armnn/labels.txt (limited to 'tests') diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ecdff7f909..0979d552de 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -49,8 +49,8 @@ if(BUILD_CAFFE_PARSER) set(CaffeAlexNet-Armnn_sources CaffeAlexNet-Armnn/CaffeAlexNet-Armnn.cpp - ImageNetDatabase.hpp - ImageNetDatabase.cpp) + CaffePreprocessor.hpp + CaffePreprocessor.cpp) CaffeParserTest(CaffeAlexNet-Armnn "${CaffeAlexNet-Armnn_sources}") set(MultipleNetworksCifar10_SRC @@ -61,20 +61,20 @@ if(BUILD_CAFFE_PARSER) set(CaffeResNet-Armnn_sources CaffeResNet-Armnn/CaffeResNet-Armnn.cpp - ImageNetDatabase.hpp - ImageNetDatabase.cpp) + CaffePreprocessor.hpp + CaffePreprocessor.cpp) CaffeParserTest(CaffeResNet-Armnn "${CaffeResNet-Armnn_sources}") set(CaffeVGG-Armnn_sources CaffeVGG-Armnn/CaffeVGG-Armnn.cpp - ImageNetDatabase.hpp - ImageNetDatabase.cpp) + CaffePreprocessor.hpp + CaffePreprocessor.cpp) CaffeParserTest(CaffeVGG-Armnn "${CaffeVGG-Armnn_sources}") set(CaffeInception_BN-Armnn_sources CaffeInception_BN-Armnn/CaffeInception_BN-Armnn.cpp - ImageNetDatabase.hpp - ImageNetDatabase.cpp) + CaffePreprocessor.hpp + CaffePreprocessor.cpp) CaffeParserTest(CaffeInception_BN-Armnn "${CaffeInception_BN-Armnn_sources}") set(CaffeYolo-Armnn_sources @@ -118,29 +118,88 @@ if(BUILD_TF_PARSER) set(TfMobileNet-Armnn_sources TfMobileNet-Armnn/TfMobileNet-Armnn.cpp - MobileNetDatabase.hpp - MobileNetDatabase.cpp) + ImagePreprocessor.hpp + ImagePreprocessor.cpp) TfParserTest(TfMobileNet-Armnn "${TfMobileNet-Armnn_sources}") set(TfInceptionV3-Armnn_sources TfInceptionV3-Armnn/TfInceptionV3-Armnn.cpp - MobileNetDatabase.hpp - MobileNetDatabase.cpp) + ImagePreprocessor.hpp + ImagePreprocessor.cpp) TfParserTest(TfInceptionV3-Armnn "${TfInceptionV3-Armnn_sources}") set(TfResNext-Armnn_sources TfResNext_Quantized-Armnn/TfResNext_Quantized-Armnn.cpp - ImageNetDatabase.hpp - ImageNetDatabase.cpp) + CaffePreprocessor.hpp + CaffePreprocessor.cpp) TfParserTest(TfResNext-Armnn "${TfResNext-Armnn_sources}") endif() -if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER) +if (BUILD_TF_LITE_PARSER) + macro(TfLiteParserTest testName sources) + add_executable_ex(${testName} ${sources}) + target_include_directories(${testName} PRIVATE ../src/armnnUtils) + + target_link_libraries(${testName} inferenceTest) + target_link_libraries(${testName} armnnTfLiteParser) + target_link_libraries(${testName} armnn) + target_link_libraries(${testName} ${CMAKE_THREAD_LIBS_INIT}) + if(OPENCL_LIBRARIES) + target_link_libraries(${testName} ${OPENCL_LIBRARIES}) + endif() + target_link_libraries(${testName} + ${Boost_SYSTEM_LIBRARY} + ${Boost_FILESYSTEM_LIBRARY} + ${Boost_PROGRAM_OPTIONS_LIBRARY}) + addDllCopyCommands(${testName}) + endmacro() + + set(TfLiteMobilenetQuantized-Armnn_sources + TfLiteMobilenetQuantized-Armnn/TfLiteMobilenetQuantized-Armnn.cpp + ImagePreprocessor.hpp + ImagePreprocessor.cpp) + TfLiteParserTest(TfLiteMobilenetQuantized-Armnn "${TfLiteMobilenetQuantized-Armnn_sources}") +endif() + +if (BUILD_ONNX_PARSER) + macro(OnnxParserTest testName sources) + add_executable_ex(${testName} ${sources}) + target_include_directories(${testName} PRIVATE ../src/armnnUtils) + + target_link_libraries(${testName} inferenceTest) + target_link_libraries(${testName} armnnOnnxParser) + target_link_libraries(${testName} armnn) + target_link_libraries(${testName} ${CMAKE_THREAD_LIBS_INIT}) + if(OPENCL_LIBRARIES) + target_link_libraries(${testName} ${OPENCL_LIBRARIES}) + endif() + target_link_libraries(${testName} + ${Boost_SYSTEM_LIBRARY} + ${Boost_FILESYSTEM_LIBRARY} + ${Boost_PROGRAM_OPTIONS_LIBRARY}) + addDllCopyCommands(${testName}) + endmacro() + + set(OnnxMnist-Armnn_sources + OnnxMnist-Armnn/OnnxMnist-Armnn.cpp + MnistDatabase.hpp + MnistDatabase.cpp) + OnnxParserTest(OnnxMnist-Armnn "${OnnxMnist-Armnn_sources}") + + set(OnnxMobileNet-Armnn_sources + OnnxMobileNet-Armnn/OnnxMobileNet-Armnn.cpp + ImagePreprocessor.hpp + ImagePreprocessor.cpp) + OnnxParserTest(OnnxMobileNet-Armnn "${OnnxMobileNet-Armnn_sources}") +endif() + +if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_PARSER) set(ExecuteNetwork_sources ExecuteNetwork/ExecuteNetwork.cpp) add_executable_ex(ExecuteNetwork ${ExecuteNetwork_sources}) target_include_directories(ExecuteNetwork PRIVATE ../src/armnnUtils) + target_include_directories(ExecuteNetwork PRIVATE ../src/armnn) if (BUILD_CAFFE_PARSER) target_link_libraries(ExecuteNetwork armnnCaffeParser) @@ -148,6 +207,14 @@ if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER) if (BUILD_TF_PARSER) target_link_libraries(ExecuteNetwork armnnTfParser) endif() + + if (BUILD_TF_LITE_PARSER) + target_link_libraries(ExecuteNetwork armnnTfLiteParser) + endif() + if (BUILD_ONNX_PARSER) + target_link_libraries(ExecuteNetwork armnnOnnxParser) + endif() + target_link_libraries(ExecuteNetwork armnn) target_link_libraries(ExecuteNetwork ${CMAKE_THREAD_LIBS_INIT}) if(OPENCL_LIBRARIES) diff --git a/tests/CaffeAlexNet-Armnn/CaffeAlexNet-Armnn.cpp b/tests/CaffeAlexNet-Armnn/CaffeAlexNet-Armnn.cpp index dce4e08d05..b7ec4f63f1 100644 --- a/tests/CaffeAlexNet-Armnn/CaffeAlexNet-Armnn.cpp +++ b/tests/CaffeAlexNet-Armnn/CaffeAlexNet-Armnn.cpp @@ -3,7 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../ImageNetDatabase.hpp" +#include "../CaffePreprocessor.hpp" #include "armnnCaffeParser/ICaffeParser.hpp" int main(int argc, char* argv[]) @@ -11,10 +11,17 @@ int main(int argc, char* argv[]) int retVal = EXIT_FAILURE; try { + using DataType = float; + using DatabaseType = CaffePreprocessor; + using ParserType = armnnCaffeParser::ICaffeParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "bvlc_alexnet_1.caffemodel", true, "data", "prob", { 0 }, - [](const char* dataDir) { return ImageNetDatabase(dataDir); }); + [](const char* dataDir, const ModelType &) { + return DatabaseType(dataDir); + }); } catch (const std::exception& e) { diff --git a/tests/CaffeCifar10AcrossChannels-Armnn/CaffeCifar10AcrossChannels-Armnn.cpp b/tests/CaffeCifar10AcrossChannels-Armnn/CaffeCifar10AcrossChannels-Armnn.cpp index fbd3312f04..ff6e93ff7c 100644 --- a/tests/CaffeCifar10AcrossChannels-Armnn/CaffeCifar10AcrossChannels-Armnn.cpp +++ b/tests/CaffeCifar10AcrossChannels-Armnn/CaffeCifar10AcrossChannels-Armnn.cpp @@ -11,11 +11,18 @@ int main(int argc, char* argv[]) int retVal = EXIT_FAILURE; try { + using DataType = float; + using DatabaseType = Cifar10Database; + using ParserType = armnnCaffeParser::ICaffeParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "cifar10_full_iter_60000.caffemodel", true, "data", "prob", { 0, 1, 2, 4, 7 }, - [](const char* dataDir) { return Cifar10Database(dataDir); }); + [](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir); + }); } catch (const std::exception& e) { diff --git a/tests/CaffeInception_BN-Armnn/CaffeInception_BN-Armnn.cpp b/tests/CaffeInception_BN-Armnn/CaffeInception_BN-Armnn.cpp index a6581bea55..fccf9aff70 100644 --- a/tests/CaffeInception_BN-Armnn/CaffeInception_BN-Armnn.cpp +++ b/tests/CaffeInception_BN-Armnn/CaffeInception_BN-Armnn.cpp @@ -3,7 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../ImageNetDatabase.hpp" +#include "../CaffePreprocessor.hpp" #include "armnnCaffeParser/ICaffeParser.hpp" int main(int argc, char* argv[]) @@ -17,11 +17,18 @@ int main(int argc, char* argv[]) {"shark.jpg", 3694} }; + using DataType = float; + using DatabaseType = CaffePreprocessor; + using ParserType = armnnCaffeParser::ICaffeParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "Inception-BN-batchsize1.caffemodel", true, "data", "softmax", { 0 }, - [&imageSet](const char* dataDir) { return ImageNetDatabase(dataDir, 224, 224, imageSet); }); + [&imageSet](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir, 224, 224, imageSet); + }); } catch (const std::exception& e) { diff --git a/tests/CaffeMnist-Armnn/CaffeMnist-Armnn.cpp b/tests/CaffeMnist-Armnn/CaffeMnist-Armnn.cpp index ec14a5d7bc..644041bb5f 100644 --- a/tests/CaffeMnist-Armnn/CaffeMnist-Armnn.cpp +++ b/tests/CaffeMnist-Armnn/CaffeMnist-Armnn.cpp @@ -11,11 +11,18 @@ int main(int argc, char* argv[]) int retVal = EXIT_FAILURE; try { + using DataType = float; + using DatabaseType = MnistDatabase; + using ParserType = armnnCaffeParser::ICaffeParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "lenet_iter_9000.caffemodel", true, "data", "prob", { 0, 1, 5, 8, 9 }, - [](const char* dataDir) { return MnistDatabase(dataDir); }); + [](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir); + }); } catch (const std::exception& e) { diff --git a/tests/CaffePreprocessor.cpp b/tests/CaffePreprocessor.cpp new file mode 100644 index 0000000000..226e57ab17 --- /dev/null +++ b/tests/CaffePreprocessor.cpp @@ -0,0 +1,47 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include "InferenceTestImage.hpp" +#include "CaffePreprocessor.hpp" + +#include +#include +#include +#include + +#include +#include +#include + +const std::vector g_DefaultImageSet = +{ + {"shark.jpg", 2} +}; + +CaffePreprocessor::CaffePreprocessor(const std::string& binaryFileDirectory, unsigned int width, unsigned int height, + const std::vector& imageSet) +: m_BinaryDirectory(binaryFileDirectory) +, m_Height(height) +, m_Width(width) +, m_ImageSet(imageSet.empty() ? g_DefaultImageSet : imageSet) +{ +} + +std::unique_ptr CaffePreprocessor::GetTestCaseData(unsigned int testCaseId) +{ + testCaseId = testCaseId % boost::numeric_cast(m_ImageSet.size()); + const ImageSet& imageSet = m_ImageSet[testCaseId]; + const std::string fullPath = m_BinaryDirectory + imageSet.first; + + InferenceTestImage image(fullPath.c_str()); + image.Resize(m_Width, m_Height, CHECK_LOCATION()); + + // The model expects image data in BGR format. + std::vector inputImageData = GetImageDataInArmNnLayoutAsFloatsSubtractingMean(ImageChannelLayout::Bgr, + image, m_MeanBgr); + + // List of labels: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a + const unsigned int label = imageSet.second; + return std::make_unique(label, std::move(inputImageData)); +} diff --git a/tests/CaffePreprocessor.hpp b/tests/CaffePreprocessor.hpp new file mode 100644 index 0000000000..90eebf97b7 --- /dev/null +++ b/tests/CaffePreprocessor.hpp @@ -0,0 +1,40 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "ClassifierTestCaseData.hpp" + +#include +#include +#include +#include + +/// Caffe requires BGR images, not normalized, mean adjusted and resized using smooth resize of STB library + +using ImageSet = std::pair; + +class CaffePreprocessor +{ +public: + using DataType = float; + using TTestCaseData = ClassifierTestCaseData; + + explicit CaffePreprocessor(const std::string& binaryFileDirectory, + unsigned int width = 227, + unsigned int height = 227, + const std::vector& imageSet = std::vector()); + std::unique_ptr GetTestCaseData(unsigned int testCaseId); + +private: + unsigned int GetNumImageElements() const { return 3 * m_Width * m_Height; } + unsigned int GetNumImageBytes() const { return 4 * GetNumImageElements(); } + + std::string m_BinaryDirectory; + unsigned int m_Height; + unsigned int m_Width; + // Mean value of the database [B, G, R]. + const std::array m_MeanBgr = {{104.007965f, 116.669472f, 122.675102f}}; + const std::vector m_ImageSet; +}; diff --git a/tests/CaffeResNet-Armnn/CaffeResNet-Armnn.cpp b/tests/CaffeResNet-Armnn/CaffeResNet-Armnn.cpp index 7cccb215a1..3b1a2945a5 100644 --- a/tests/CaffeResNet-Armnn/CaffeResNet-Armnn.cpp +++ b/tests/CaffeResNet-Armnn/CaffeResNet-Armnn.cpp @@ -3,7 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../ImageNetDatabase.hpp" +#include "../CaffePreprocessor.hpp" #include "armnnCaffeParser/ICaffeParser.hpp" int main(int argc, char* argv[]) @@ -20,12 +20,18 @@ int main(int argc, char* argv[]) armnn::TensorShape inputTensorShape({ 1, 3, 224, 224 }); + using DataType = float; + using DatabaseType = CaffePreprocessor; + using ParserType = armnnCaffeParser::ICaffeParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "ResNet_50_ilsvrc15_model.caffemodel", true, "data", "prob", { 0, 1 }, - [&imageSet](const char* dataDir) { return ImageNetDatabase(dataDir, 224, 224, imageSet); }, - &inputTensorShape); + [&imageSet](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir, 224, 224, imageSet); + }, &inputTensorShape); } catch (const std::exception& e) { diff --git a/tests/CaffeSqueezeNet1_0-Armnn/CaffeSqueezeNet1_0-Armnn.cpp b/tests/CaffeSqueezeNet1_0-Armnn/CaffeSqueezeNet1_0-Armnn.cpp index f0b48836f1..1ca8429bd2 100644 --- a/tests/CaffeSqueezeNet1_0-Armnn/CaffeSqueezeNet1_0-Armnn.cpp +++ b/tests/CaffeSqueezeNet1_0-Armnn/CaffeSqueezeNet1_0-Armnn.cpp @@ -3,13 +3,13 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../ImageNetDatabase.hpp" +#include "../CaffePreprocessor.hpp" #include "armnnCaffeParser/ICaffeParser.hpp" int main(int argc, char* argv[]) { - return armnn::test::ClassifierInferenceTestMain( + return armnn::test::ClassifierInferenceTestMain( argc, argv, "squeezenet.caffemodel", true, "data", "output", { 0 }, - [](const char* dataDir) { return ImageNetDatabase(dataDir); }); + [](const char* dataDir) { return CaffePreprocessor(dataDir); }); } diff --git a/tests/CaffeVGG-Armnn/CaffeVGG-Armnn.cpp b/tests/CaffeVGG-Armnn/CaffeVGG-Armnn.cpp index b859042935..99ced3dc43 100644 --- a/tests/CaffeVGG-Armnn/CaffeVGG-Armnn.cpp +++ b/tests/CaffeVGG-Armnn/CaffeVGG-Armnn.cpp @@ -3,7 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../ImageNetDatabase.hpp" +#include "../CaffePreprocessor.hpp" #include "armnnCaffeParser/ICaffeParser.hpp" int main(int argc, char* argv[]) @@ -12,12 +12,18 @@ int main(int argc, char* argv[]) int retVal = EXIT_FAILURE; try { + using DataType = float; + using DatabaseType = CaffePreprocessor; + using ParserType = armnnCaffeParser::ICaffeParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "VGG_CNN_S.caffemodel", true, "input", "prob", { 0 }, - [](const char* dataDir) { return ImageNetDatabase(dataDir, 224, 224); }, - &inputTensorShape); + [](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir, 224, 224); + }, &inputTensorShape); } catch (const std::exception& e) { diff --git a/tests/CaffeYolo-Armnn/CaffeYolo-Armnn.cpp b/tests/CaffeYolo-Armnn/CaffeYolo-Armnn.cpp index ad79d49f0c..7396b7672c 100644 --- a/tests/CaffeYolo-Armnn/CaffeYolo-Armnn.cpp +++ b/tests/CaffeYolo-Armnn/CaffeYolo-Armnn.cpp @@ -37,6 +37,7 @@ int main(int argc, char* argv[]) modelParams.m_IsModelBinary = true; modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice; modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel; + modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode; return std::make_unique(modelParams); }); diff --git a/tests/Cifar10Database.hpp b/tests/Cifar10Database.hpp index a4998cee1d..1a819aad64 100644 --- a/tests/Cifar10Database.hpp +++ b/tests/Cifar10Database.hpp @@ -12,7 +12,8 @@ class Cifar10Database { public: - using TTestCaseData = ClassifierTestCaseData; + using DataType = float; + using TTestCaseData = ClassifierTestCaseData; explicit Cifar10Database(const std::string& binaryFileDirectory, bool rgbPack = false); std::unique_ptr GetTestCaseData(unsigned int testCaseId); diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index 74737e2718..fdec15a61d 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -3,30 +3,50 @@ // See LICENSE file in the project root for full license information. // #include "armnn/ArmNN.hpp" + +#include + #if defined(ARMNN_CAFFE_PARSER) #include "armnnCaffeParser/ICaffeParser.hpp" #endif #if defined(ARMNN_TF_PARSER) #include "armnnTfParser/ITfParser.hpp" #endif -#include "Logging.hpp" +#if defined(ARMNN_TF_LITE_PARSER) +#include "armnnTfLiteParser/ITfLiteParser.hpp" +#endif +#if defined(ARMNN_ONNX_PARSER) +#include "armnnOnnxParser/IOnnxParser.hpp" +#endif +#include "CsvReader.hpp" #include "../InferenceTest.hpp" -#include +#include +#include + +#include #include #include +#include #include #include +#include +#include +#include +#include namespace { +// Configure boost::program_options for command-line parsing and validation. +namespace po = boost::program_options; + template std::vector ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc) { std::vector result; - // Process line-by-line + // Processes line-by-line. std::string line; while (std::getline(stream, line)) { @@ -60,6 +80,46 @@ std::vector ParseArrayImpl(std::istream& stream, TParseElementFunc parseEleme return result; } +bool CheckOption(const po::variables_map& vm, + const char* option) +{ + // Check that the given option is valid. + if (option == nullptr) + { + return false; + } + + // Check whether 'option' is provided. + return vm.find(option) != vm.end(); +} + +void CheckOptionDependency(const po::variables_map& vm, + const char* option, + const char* required) +{ + // Check that the given options are valid. + if (option == nullptr || required == nullptr) + { + throw po::error("Invalid option to check dependency for"); + } + + // Check that if 'option' is provided, 'required' is also provided. + if (CheckOption(vm, option) && !vm[option].defaulted()) + { + if (CheckOption(vm, required) == 0 || vm[required].defaulted()) + { + throw po::error(std::string("Option '") + option + "' requires option '" + required + "'."); + } + } +} + +void CheckOptionDependencies(const po::variables_map& vm) +{ + CheckOptionDependency(vm, "model-path", "model-format"); + CheckOptionDependency(vm, "model-path", "input-name"); + CheckOptionDependency(vm, "model-path", "input-tensor-data"); + CheckOptionDependency(vm, "model-path", "output-name"); + CheckOptionDependency(vm, "input-tensor-shape", "model-path"); } template @@ -87,26 +147,61 @@ void PrintArray(const std::vector& v) printf("\n"); } +void RemoveDuplicateDevices(std::vector& computeDevices) +{ + // Mark the duplicate devices as 'Undefined'. + for (auto i = computeDevices.begin(); i != computeDevices.end(); ++i) + { + for (auto j = std::next(i); j != computeDevices.end(); ++j) + { + if (*j == *i) + { + *j = armnn::Compute::Undefined; + } + } + } + + // Remove 'Undefined' devices. + computeDevices.erase(std::remove(computeDevices.begin(), computeDevices.end(), armnn::Compute::Undefined), + computeDevices.end()); +} + +bool CheckDevicesAreValid(const std::vector& computeDevices) +{ + return (!computeDevices.empty() + && std::none_of(computeDevices.begin(), computeDevices.end(), + [](armnn::Compute c){ return c == armnn::Compute::Undefined; })); +} + +} // namespace + template -int MainImpl(const char* modelPath, bool isModelBinary, armnn::Compute computeDevice, - const char* inputName, const armnn::TensorShape* inputTensorShape, const char* inputTensorDataFilePath, - const char* outputName) +int MainImpl(const char* modelPath, + bool isModelBinary, + const std::vector& computeDevice, + const char* inputName, + const armnn::TensorShape* inputTensorShape, + const char* inputTensorDataFilePath, + const char* outputName, + bool enableProfiling, + const size_t subgraphId, + const std::shared_ptr& runtime = nullptr) { - // Load input tensor + // Loads input tensor. std::vector input; { std::ifstream inputTensorFile(inputTensorDataFilePath); if (!inputTensorFile.good()) { BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath; - return 1; + return EXIT_FAILURE; } input = ParseArray(inputTensorFile); } try { - // Create an InferenceModel, which will parse the model and load it into an IRuntime + // Creates an InferenceModel, which will parse the model and load it into an IRuntime. typename InferenceModel::Params params; params.m_ModelPath = modelPath; params.m_IsModelBinary = isModelBinary; @@ -114,27 +209,235 @@ int MainImpl(const char* modelPath, bool isModelBinary, armnn::Compute computeDe params.m_InputBinding = inputName; params.m_InputTensorShape = inputTensorShape; params.m_OutputBinding = outputName; - InferenceModel model(params); + params.m_EnableProfiling = enableProfiling; + params.m_SubgraphId = subgraphId; + InferenceModel model(params, runtime); - // Execute the model + // Executes the model. std::vector output(model.GetOutputSize()); model.Run(input, output); - // Print the output tensor + // Prints the output tensor. PrintArray(output); } catch (armnn::Exception const& e) { BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what(); - return 1; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +// This will run a test +int RunTest(const std::string& modelFormat, + const std::string& inputTensorShapeStr, + const vector& computeDevice, + const std::string& modelPath, + const std::string& inputName, + const std::string& inputTensorDataFilePath, + const std::string& outputName, + bool enableProfiling, + const size_t subgraphId, + const std::shared_ptr& runtime = nullptr) +{ + // Parse model binary flag from the model-format string we got from the command-line + bool isModelBinary; + if (modelFormat.find("bin") != std::string::npos) + { + isModelBinary = true; + } + else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos) + { + isModelBinary = false; + } + else + { + BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'"; + return EXIT_FAILURE; } - return 0; + // Parse input tensor shape from the string we got from the command-line. + std::unique_ptr inputTensorShape; + if (!inputTensorShapeStr.empty()) + { + std::stringstream ss(inputTensorShapeStr); + std::vector dims = ParseArray(ss); + + try + { + // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught. + inputTensorShape = std::make_unique(dims.size(), dims.data()); + } + catch (const armnn::InvalidArgumentException& e) + { + BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what(); + return EXIT_FAILURE; + } + } + + // Forward to implementation based on the parser type + if (modelFormat.find("caffe") != std::string::npos) + { +#if defined(ARMNN_CAFFE_PARSER) + return MainImpl(modelPath.c_str(), isModelBinary, computeDevice, + inputName.c_str(), inputTensorShape.get(), + inputTensorDataFilePath.c_str(), outputName.c_str(), + enableProfiling, subgraphId, runtime); +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support."; + return EXIT_FAILURE; +#endif + } + else if (modelFormat.find("onnx") != std::string::npos) +{ +#if defined(ARMNN_ONNX_PARSER) + return MainImpl(modelPath.c_str(), isModelBinary, computeDevice, + inputName.c_str(), inputTensorShape.get(), + inputTensorDataFilePath.c_str(), outputName.c_str(), + enableProfiling, subgraphId, runtime); +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support."; + return EXIT_FAILURE; +#endif + } + else if (modelFormat.find("tensorflow") != std::string::npos) + { +#if defined(ARMNN_TF_PARSER) + return MainImpl(modelPath.c_str(), isModelBinary, computeDevice, + inputName.c_str(), inputTensorShape.get(), + inputTensorDataFilePath.c_str(), outputName.c_str(), + enableProfiling, subgraphId, runtime); +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support."; + return EXIT_FAILURE; +#endif + } + else if(modelFormat.find("tflite") != std::string::npos) + { +#if defined(ARMNN_TF_LITE_PARSER) + if (! isModelBinary) + { + BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \ + for tflite files"; + return EXIT_FAILURE; + } + return MainImpl(modelPath.c_str(), isModelBinary, computeDevice, + inputName.c_str(), inputTensorShape.get(), + inputTensorDataFilePath.c_str(), outputName.c_str(), + enableProfiling, subgraphId, runtime); +#else + BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << + "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'"; + return EXIT_FAILURE; +#endif + } + else + { + BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << + "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'"; + return EXIT_FAILURE; + } } -int main(int argc, char* argv[]) +int RunCsvTest(const armnnUtils::CsvRow &csvRow, + const std::shared_ptr& runtime) { - // Configure logging for both the ARMNN library and this test program + std::string modelFormat; + std::string modelPath; + std::string inputName; + std::string inputTensorShapeStr; + std::string inputTensorDataFilePath; + std::string outputName; + + size_t subgraphId = 0; + + po::options_description desc("Options"); + try + { + desc.add_options() + ("model-format,f", po::value(&modelFormat), + "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.") + ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite," + " .onnx") + ("compute,c", po::value>()->multitoken(), + "The preferred order of devices to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") + ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.") + ("subgraph-number,n", po::value(&subgraphId)->default_value(0), "Id of the subgraph to be " + "executed. Defaults to 0") + ("input-tensor-shape,s", po::value(&inputTensorShapeStr), + "The shape of the input tensor in the network as a flat array of integers separated by whitespace. " + "This parameter is optional, depending on the network.") + ("input-tensor-data,d", po::value(&inputTensorDataFilePath), + "Path to a file containing the input data as a flat array separated by whitespace.") + ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.") + ("event-based-profiling,e", po::bool_switch()->default_value(false), + "Enables built in profiler. If unset, defaults to off."); + } + catch (const std::exception& e) + { + // Coverity points out that default_value(...) can throw a bad_lexical_cast, + // and that desc.add_options() can throw boost::io::too_few_args. + // They really won't in any of these cases. + BOOST_ASSERT_MSG(false, "Caught unexpected exception"); + BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what(); + return EXIT_FAILURE; + } + + std::vector clOptions; + clOptions.reserve(csvRow.values.size()); + for (const std::string& value : csvRow.values) + { + clOptions.push_back(value.c_str()); + } + + po::variables_map vm; + try + { + po::store(po::parse_command_line(static_cast(clOptions.size()), clOptions.data(), desc), vm); + + po::notify(vm); + + CheckOptionDependencies(vm); + } + catch (const po::error& e) + { + std::cerr << e.what() << std::endl << std::endl; + std::cerr << desc << std::endl; + return EXIT_FAILURE; + } + + // Remove leading and trailing whitespaces from the parsed arguments. + boost::trim(modelFormat); + boost::trim(modelPath); + boost::trim(inputName); + boost::trim(inputTensorShapeStr); + boost::trim(inputTensorDataFilePath); + boost::trim(outputName); + + // Get the value of the switch arguments. + bool enableProfiling = vm["event-based-profiling"].as(); + + // Get the preferred order of compute devices. + std::vector computeDevices = vm["compute"].as>(); + + // Remove duplicates from the list of compute devices. + RemoveDuplicateDevices(computeDevices); + + // Check that the specified compute devices are valid. + if (!CheckDevicesAreValid(computeDevices)) + { + BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains an invalid compute"; + return EXIT_FAILURE; + } + + return RunTest(modelFormat, inputTensorShapeStr, computeDevices, + modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId, runtime); +} + +int main(int argc, const char* argv[]) +{ + // Configures logging for both the ARMNN library and this test program. #ifdef NDEBUG armnn::LogSeverity level = armnn::LogSeverity::Info; #else @@ -143,8 +446,7 @@ int main(int argc, char* argv[]) armnn::ConfigureLogging(true, true, level); armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level); - // Configure boost::program_options for command-line parsing - namespace po = boost::program_options; + std::string testCasesFile; std::string modelFormat; std::string modelPath; @@ -152,25 +454,36 @@ int main(int argc, char* argv[]) std::string inputTensorShapeStr; std::string inputTensorDataFilePath; std::string outputName; - armnn::Compute computeDevice; + + size_t subgraphId = 0; po::options_description desc("Options"); try { desc.add_options() ("help", "Display usage information") - ("model-format,f", po::value(&modelFormat)->required(), - "caffe-binary, caffe-text, tensorflow-binary or tensorflow-text.") - ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt") - ("compute,c", po::value(&computeDevice)->required(), - "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") - ("input-name,i", po::value(&inputName)->required(), "Identifier of the input tensor in the network.") + ("test-cases,t", po::value(&testCasesFile), "Path to a CSV file containing test cases to run. " + "If set, further parameters -- with the exception of compute device and concurrency -- will be ignored, " + "as they are expected to be defined in the file for each test in particular.") + ("concurrent,n", po::bool_switch()->default_value(false), + "Whether or not the test cases should be executed in parallel") + ("model-format,f", po::value(&modelFormat), + "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.") + ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt," + " .tflite, .onnx") + ("compute,c", po::value>()->multitoken(), + "The preferred order of devices to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") + ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.") + ("subgraph-number,x", po::value(&subgraphId)->default_value(0), "Id of the subgraph to be executed." + "Defaults to 0") ("input-tensor-shape,s", po::value(&inputTensorShapeStr), - "The shape of the input tensor in the network as a flat array of integers separated by whitespace. " - "This parameter is optional, depending on the network.") - ("input-tensor-data,d", po::value(&inputTensorDataFilePath)->required(), + "The shape of the input tensor in the network as a flat array of integers separated by whitespace. " + "This parameter is optional, depending on the network.") + ("input-tensor-data,d", po::value(&inputTensorDataFilePath), "Path to a file containing the input data as a flat array separated by whitespace.") - ("output-name,o", po::value(&outputName)->required(), "Identifier of the output tensor in the network."); + ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.") + ("event-based-profiling,e", po::bool_switch()->default_value(false), + "Enables built in profiler. If unset, defaults to off."); } catch (const std::exception& e) { @@ -179,93 +492,128 @@ int main(int argc, char* argv[]) // They really won't in any of these cases. BOOST_ASSERT_MSG(false, "Caught unexpected exception"); BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what(); - return 1; + return EXIT_FAILURE; } - // Parse the command-line + // Parses the command-line. po::variables_map vm; try { po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help") || argc <= 1) + if (CheckOption(vm, "help") || argc <= 1) { std::cout << "Executes a neural network model using the provided input tensor. " << std::endl; std::cout << "Prints the resulting output tensor." << std::endl; std::cout << std::endl; std::cout << desc << std::endl; - return 1; + return EXIT_SUCCESS; } po::notify(vm); } - catch (po::error& e) + catch (const po::error& e) { std::cerr << e.what() << std::endl << std::endl; std::cerr << desc << std::endl; - return 1; + return EXIT_FAILURE; } - // Parse model binary flag from the model-format string we got from the command-line - bool isModelBinary; - if (modelFormat.find("bin") != std::string::npos) - { - isModelBinary = true; - } - else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos) + // Get the value of the switch arguments. + bool concurrent = vm["concurrent"].as(); + bool enableProfiling = vm["event-based-profiling"].as(); + + // Check whether we have to load test cases from a file. + if (CheckOption(vm, "test-cases")) { - isModelBinary = false; + // Check that the file exists. + if (!boost::filesystem::exists(testCasesFile)) + { + BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" does not exist"; + return EXIT_FAILURE; + } + + // Parse CSV file and extract test cases + armnnUtils::CsvReader reader; + std::vector testCases = reader.ParseFile(testCasesFile); + + // Check that there is at least one test case to run + if (testCases.empty()) + { + BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" has no test cases"; + return EXIT_FAILURE; + } + + // Create runtime + armnn::IRuntime::CreationOptions options; + std::shared_ptr runtime(armnn::IRuntime::Create(options)); + + const std::string executableName("ExecuteNetwork"); + + // Check whether we need to run the test cases concurrently + if (concurrent) + { + std::vector> results; + results.reserve(testCases.size()); + + // Run each test case in its own thread + for (auto& testCase : testCases) + { + testCase.values.insert(testCase.values.begin(), executableName); + results.push_back(std::async(std::launch::async, RunCsvTest, std::cref(testCase), std::cref(runtime))); + } + + // Check results + for (auto& result : results) + { + if (result.get() != EXIT_SUCCESS) + { + return EXIT_FAILURE; + } + } + } + else + { + // Run tests sequentially + for (auto& testCase : testCases) + { + testCase.values.insert(testCase.values.begin(), executableName); + if (RunCsvTest(testCase, runtime) != EXIT_SUCCESS) + { + return EXIT_FAILURE; + } + } + } + + return EXIT_SUCCESS; } - else + else // Run single test { - BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'"; - return 1; - } + // Get the preferred order of compute devices. + std::vector computeDevices = vm["compute"].as>(); - // Parse input tensor shape from the string we got from the command-line. - std::unique_ptr inputTensorShape; - if (!inputTensorShapeStr.empty()) - { - std::stringstream ss(inputTensorShapeStr); - std::vector dims = ParseArray(ss); + // Remove duplicates from the list of compute devices. + RemoveDuplicateDevices(computeDevices); + + // Check that the specified compute devices are valid. + if (!CheckDevicesAreValid(computeDevices)) + { + BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains an invalid compute"; + return EXIT_FAILURE; + } try { - // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught. - inputTensorShape = std::make_unique(dims.size(), dims.data()); + CheckOptionDependencies(vm); } - catch (const armnn::InvalidArgumentException& e) + catch (const po::error& e) { - BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what(); - return 1; + std::cerr << e.what() << std::endl << std::endl; + std::cerr << desc << std::endl; + return EXIT_FAILURE; } - } - // Forward to implementation based on the parser type - if (modelFormat.find("caffe") != std::string::npos) - { -#if defined(ARMNN_CAFFE_PARSER) - return MainImpl(modelPath.c_str(), isModelBinary, computeDevice, - inputName.c_str(), inputTensorShape.get(), inputTensorDataFilePath.c_str(), outputName.c_str()); -#else - BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support."; - return 1; -#endif - } - else if (modelFormat.find("tensorflow") != std::string::npos) - { -#if defined(ARMNN_TF_PARSER) - return MainImpl(modelPath.c_str(), isModelBinary, computeDevice, - inputName.c_str(), inputTensorShape.get(), inputTensorDataFilePath.c_str(), outputName.c_str()); -#else - BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support."; - return 1; -#endif - } - else - { - BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << - "'. Please include 'caffe' or 'tensorflow'"; - return 1; + return RunTest(modelFormat, inputTensorShapeStr, computeDevices, + modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId); } } diff --git a/tests/ImageNetDatabase.cpp b/tests/ImageNetDatabase.cpp deleted file mode 100644 index ac4bc21ff9..0000000000 --- a/tests/ImageNetDatabase.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// See LICENSE file in the project root for full license information. -// -#include "InferenceTestImage.hpp" -#include "ImageNetDatabase.hpp" - -#include -#include -#include -#include - -#include -#include -#include - -const std::vector g_DefaultImageSet = -{ - {"shark.jpg", 2} -}; - -ImageNetDatabase::ImageNetDatabase(const std::string& binaryFileDirectory, unsigned int width, unsigned int height, - const std::vector& imageSet) -: m_BinaryDirectory(binaryFileDirectory) -, m_Height(height) -, m_Width(width) -, m_ImageSet(imageSet.empty() ? g_DefaultImageSet : imageSet) -{ -} - -std::unique_ptr ImageNetDatabase::GetTestCaseData(unsigned int testCaseId) -{ - testCaseId = testCaseId % boost::numeric_cast(m_ImageSet.size()); - const ImageSet& imageSet = m_ImageSet[testCaseId]; - const std::string fullPath = m_BinaryDirectory + imageSet.first; - - InferenceTestImage image(fullPath.c_str()); - image.Resize(m_Width, m_Height); - - // The model expects image data in BGR format - std::vector inputImageData = GetImageDataInArmNnLayoutAsFloatsSubtractingMean(ImageChannelLayout::Bgr, - image, m_MeanBgr); - - // list of labels: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a - const unsigned int label = imageSet.second; - return std::make_unique(label, std::move(inputImageData)); -} diff --git a/tests/ImageNetDatabase.hpp b/tests/ImageNetDatabase.hpp deleted file mode 100644 index cd990c458a..0000000000 --- a/tests/ImageNetDatabase.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// See LICENSE file in the project root for full license information. -// -#pragma once - -#include "ClassifierTestCaseData.hpp" - -#include -#include -#include -#include - -using ImageSet = std::pair; - -class ImageNetDatabase -{ -public: - using TTestCaseData = ClassifierTestCaseData; - - explicit ImageNetDatabase(const std::string& binaryFileDirectory, - unsigned int width = 227, - unsigned int height = 227, - const std::vector& imageSet = std::vector()); - std::unique_ptr GetTestCaseData(unsigned int testCaseId); - -private: - unsigned int GetNumImageElements() const { return 3 * m_Width * m_Height; } - unsigned int GetNumImageBytes() const { return 4 * GetNumImageElements(); } - - std::string m_BinaryDirectory; - unsigned int m_Height; - unsigned int m_Width; - //mean value of the database [B, G, R] - const std::array m_MeanBgr = {{104.007965f, 116.669472f, 122.675102f}}; - const std::vector m_ImageSet; -}; \ No newline at end of file diff --git a/tests/ImagePreprocessor.cpp b/tests/ImagePreprocessor.cpp new file mode 100644 index 0000000000..4e46b914ae --- /dev/null +++ b/tests/ImagePreprocessor.cpp @@ -0,0 +1,74 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include "InferenceTestImage.hpp" +#include "ImagePreprocessor.hpp" +#include "Permute.hpp" +#include + +#include +#include +#include + +#include +#include +#include + +template +unsigned int ImagePreprocessor::GetLabelAndResizedImageAsFloat(unsigned int testCaseId, + std::vector & result) +{ + testCaseId = testCaseId % boost::numeric_cast(m_ImageSet.size()); + const ImageSet& imageSet = m_ImageSet[testCaseId]; + const std::string fullPath = m_BinaryDirectory + imageSet.first; + + InferenceTestImage image(fullPath.c_str()); + + // this ResizeBilinear result is closer to the tensorflow one than STB. + // there is still some difference though, but the inference results are + // similar to tensorflow for MobileNet + + result = image.Resize(m_Width, m_Height, CHECK_LOCATION(), + InferenceTestImage::ResizingMethods::BilinearAndNormalized, + m_Mean, m_Stddev); + + if (m_DataFormat == DataFormat::NCHW) + { + const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 }; + armnn::TensorShape dstShape({1, 3, m_Height, m_Width}); + std::vector tempImage(result.size()); + armnnUtils::Permute(dstShape, NHWCToArmNN, result.data(), tempImage.data()); + result.swap(tempImage); + } + + return imageSet.second; +} + +template <> +std::unique_ptr::TTestCaseData> +ImagePreprocessor::GetTestCaseData(unsigned int testCaseId) +{ + std::vector resized; + auto label = GetLabelAndResizedImageAsFloat(testCaseId, resized); + return std::make_unique(label, std::move(resized)); +} + +template <> +std::unique_ptr::TTestCaseData> +ImagePreprocessor::GetTestCaseData(unsigned int testCaseId) +{ + std::vector resized; + auto label = GetLabelAndResizedImageAsFloat(testCaseId, resized); + + size_t resizedSize = resized.size(); + std::vector quantized(resized.size()); + + for (size_t i=0; i(resized[i], + m_Scale, + m_Offset); + } + return std::make_unique(label, std::move(quantized)); +} diff --git a/tests/ImagePreprocessor.hpp b/tests/ImagePreprocessor.hpp new file mode 100644 index 0000000000..b8a473d92c --- /dev/null +++ b/tests/ImagePreprocessor.hpp @@ -0,0 +1,73 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "ClassifierTestCaseData.hpp" + +#include +#include +#include +#include + +///Tf requires RGB images, normalized in range [0, 1] and resized using Bilinear algorithm + + +using ImageSet = std::pair; + +template +class ImagePreprocessor +{ +public: + using DataType = TDataType; + using TTestCaseData = ClassifierTestCaseData; + + enum DataFormat + { + NHWC, + NCHW + }; + + explicit ImagePreprocessor(const std::string& binaryFileDirectory, + unsigned int width, + unsigned int height, + const std::vector& imageSet, + float scale=1.0, + int32_t offset=0, + const std::array mean={{0, 0, 0}}, + const std::array stddev={{1, 1, 1}}, + DataFormat dataFormat=DataFormat::NHWC) + : m_BinaryDirectory(binaryFileDirectory) + , m_Height(height) + , m_Width(width) + , m_Scale(scale) + , m_Offset(offset) + , m_ImageSet(imageSet) + , m_Mean(mean) + , m_Stddev(stddev) + , m_DataFormat(dataFormat) + { + } + + std::unique_ptr GetTestCaseData(unsigned int testCaseId); + +private: + unsigned int GetNumImageElements() const { return 3 * m_Width * m_Height; } + unsigned int GetNumImageBytes() const { return sizeof(DataType) * GetNumImageElements(); } + unsigned int GetLabelAndResizedImageAsFloat(unsigned int testCaseId, + std::vector & result); + + std::string m_BinaryDirectory; + unsigned int m_Height; + unsigned int m_Width; + // Quantization parameters + float m_Scale; + int32_t m_Offset; + const std::vector m_ImageSet; + + const std::array m_Mean; + const std::array m_Stddev; + + DataFormat m_DataFormat; +}; diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index f5f00378ca..dfd21bbed1 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -4,7 +4,15 @@ // #pragma once #include "armnn/ArmNN.hpp" -#include "HeapProfiling.hpp" + +#if defined(ARMNN_TF_LITE_PARSER) +#include "armnnTfLiteParser/ITfLiteParser.hpp" +#endif + +#include +#if defined(ARMNN_ONNX_PARSER) +#include "armnnOnnxParser/IOnnxParser.hpp" +#endif #include #include @@ -16,9 +24,148 @@ #include #include #include +#include + +namespace InferenceModelInternal +{ +// This needs to go when the armnnCaffeParser, armnnTfParser and armnnTfLiteParser +// definitions of BindingPointInfo gets consolidated. +using BindingPointInfo = std::pair; + +using QuantizationParams = std::pair; + +struct Params +{ + std::string m_ModelPath; + std::string m_InputBinding; + std::string m_OutputBinding; + const armnn::TensorShape* m_InputTensorShape; + std::vector m_ComputeDevice; + bool m_EnableProfiling; + size_t m_SubgraphId; + bool m_IsModelBinary; + bool m_VisualizePostOptimizationModel; + bool m_EnableFp16TurboMode; + + Params() + : m_InputTensorShape(nullptr) + , m_ComputeDevice{armnn::Compute::CpuRef} + , m_EnableProfiling(false) + , m_SubgraphId(0) + , m_IsModelBinary(true) + , m_VisualizePostOptimizationModel(false) + , m_EnableFp16TurboMode(false) + {} +}; + +} // namespace InferenceModelInternal + +template +struct CreateNetworkImpl +{ +public: + using Params = InferenceModelInternal::Params; + using BindingPointInfo = InferenceModelInternal::BindingPointInfo; + + static armnn::INetworkPtr Create(const Params& params, + BindingPointInfo& inputBindings, + BindingPointInfo& outputBindings) + { + const std::string& modelPath = params.m_ModelPath; + + // Create a network from a file on disk + auto parser(IParser::Create()); + + std::map inputShapes; + if (params.m_InputTensorShape) + { + inputShapes[params.m_InputBinding] = *params.m_InputTensorShape; + } + std::vector requestedOutputs{ params.m_OutputBinding }; + armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; + + { + ARMNN_SCOPED_HEAP_PROFILING("Parsing"); + // Handle text and binary input differently by calling the corresponding parser function + network = (params.m_IsModelBinary ? + parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) : + parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs)); + } + + inputBindings = parser->GetNetworkInputBindingInfo(params.m_InputBinding); + outputBindings = parser->GetNetworkOutputBindingInfo(params.m_OutputBinding); + return network; + } +}; + +#if defined(ARMNN_TF_LITE_PARSER) +template <> +struct CreateNetworkImpl +{ +public: + using IParser = armnnTfLiteParser::ITfLiteParser; + using Params = InferenceModelInternal::Params; + using BindingPointInfo = InferenceModelInternal::BindingPointInfo; + + static armnn::INetworkPtr Create(const Params& params, + BindingPointInfo& inputBindings, + BindingPointInfo& outputBindings) + { + const std::string& modelPath = params.m_ModelPath; + + // Create a network from a file on disk + auto parser(IParser::Create()); + + armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; + + { + ARMNN_SCOPED_HEAP_PROFILING("Parsing"); + network = parser->CreateNetworkFromBinaryFile(modelPath.c_str()); + } + + inputBindings = parser->GetNetworkInputBindingInfo(params.m_SubgraphId, params.m_InputBinding); + outputBindings = parser->GetNetworkOutputBindingInfo(params.m_SubgraphId, params.m_OutputBinding); + return network; + } +}; +#endif + +#if defined(ARMNN_ONNX_PARSER) +template <> +struct CreateNetworkImpl +{ +public: + using IParser = armnnOnnxParser::IOnnxParser; + using Params = InferenceModelInternal::Params; + using BindingPointInfo = InferenceModelInternal::BindingPointInfo; + + static armnn::INetworkPtr Create(const Params& params, + BindingPointInfo& inputBindings, + BindingPointInfo& outputBindings) + { + const std::string& modelPath = params.m_ModelPath; + + // Create a network from a file on disk + auto parser(IParser::Create()); + + armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; + + { + ARMNN_SCOPED_HEAP_PROFILING("Parsing"); + network = (params.m_IsModelBinary ? + parser->CreateNetworkFromBinaryFile(modelPath.c_str()) : + parser->CreateNetworkFromTextFile(modelPath.c_str())); + } + + inputBindings = parser->GetNetworkInputBindingInfo(params.m_InputBinding); + outputBindings = parser->GetNetworkOutputBindingInfo(params.m_OutputBinding); + return network; + } +}; +#endif template -inline armnn::InputTensors MakeInputTensors(const std::pair& input, +inline armnn::InputTensors MakeInputTensors(const InferenceModelInternal::BindingPointInfo& input, const TContainer& inputTensorData) { if (inputTensorData.size() != input.second.GetNumElements()) @@ -30,7 +177,7 @@ inline armnn::InputTensors MakeInputTensors(const std::pair -inline armnn::OutputTensors MakeOutputTensors(const std::pair& output, +inline armnn::OutputTensors MakeOutputTensors(const InferenceModelInternal::BindingPointInfo& output, TContainer& outputTensorData) { if (outputTensorData.size() != output.second.GetNumElements()) @@ -48,17 +195,21 @@ inline armnn::OutputTensors MakeOutputTensors(const std::pair class InferenceModel { public: using DataType = TDataType; + using Params = InferenceModelInternal::Params; struct CommandLineOptions { std::string m_ModelDir; - armnn::Compute m_ComputeDevice; + std::vector m_ComputeDevice; bool m_VisualizePostOptimizationModel; + bool m_EnableFp16TurboMode; }; static void AddCommandLineOptions(boost::program_options::options_description& desc, CommandLineOptions& options) @@ -67,66 +218,47 @@ public: desc.add_options() ("model-dir,m", po::value(&options.m_ModelDir)->required(), - "Path to directory containing model files (.caffemodel/.prototxt)") - ("compute,c", po::value(&options.m_ComputeDevice)->default_value(armnn::Compute::CpuAcc), + "Path to directory containing model files (.caffemodel/.prototxt/.tflite)") + ("compute,c", po::value>(&options.m_ComputeDevice)->default_value + ({armnn::Compute::CpuAcc, armnn::Compute::CpuRef}), "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") ("visualize-optimized-model,v", po::value(&options.m_VisualizePostOptimizationModel)->default_value(false), "Produce a dot file useful for visualizing the graph post optimization." - "The file will have the same name as the model with the .dot extention."); + "The file will have the same name as the model with the .dot extention.") + ("fp16-turbo-mode", po::value(&options.m_EnableFp16TurboMode)->default_value(false), + "If this option is enabled FP32 layers, weights and biases will be converted " + "to FP16 where the backend supports it."); } - struct Params + InferenceModel(const Params& params, const std::shared_ptr& runtime = nullptr) + : m_EnableProfiling(params.m_EnableProfiling) { - std::string m_ModelPath; - std::string m_InputBinding; - std::string m_OutputBinding; - const armnn::TensorShape* m_InputTensorShape; - armnn::Compute m_ComputeDevice; - bool m_IsModelBinary; - bool m_VisualizePostOptimizationModel; - - Params() - : m_InputTensorShape(nullptr) - , m_ComputeDevice(armnn::Compute::CpuRef) - , m_IsModelBinary(true) - , m_VisualizePostOptimizationModel(false) + if (runtime) { + m_Runtime = runtime; } - }; - - - InferenceModel(const Params& params) - : m_Runtime(armnn::IRuntime::Create(params.m_ComputeDevice)) - { - const std::string& modelPath = params.m_ModelPath; - - // Create a network from a file on disk - auto parser(IParser::Create()); - - std::map inputShapes; - if (params.m_InputTensorShape) + else { - inputShapes[params.m_InputBinding] = *params.m_InputTensorShape; + armnn::IRuntime::CreationOptions options; + m_Runtime = std::move(armnn::IRuntime::Create(options)); } - std::vector requestedOutputs{ params.m_OutputBinding }; - armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; - { - ARMNN_SCOPED_HEAP_PROFILING("Parsing"); - // Handle text and binary input differently by calling the corresponding parser function - network = (params.m_IsModelBinary ? - parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) : - parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs)); - } - - m_InputBindingInfo = parser->GetNetworkInputBindingInfo(params.m_InputBinding); - m_OutputBindingInfo = parser->GetNetworkOutputBindingInfo(params.m_OutputBinding); + armnn::INetworkPtr network = CreateNetworkImpl::Create(params, m_InputBindingInfo, + m_OutputBindingInfo); armnn::IOptimizedNetworkPtr optNet{nullptr, [](armnn::IOptimizedNetwork *){}}; { ARMNN_SCOPED_HEAP_PROFILING("Optimizing"); - optNet = armnn::Optimize(*network, m_Runtime->GetDeviceSpec()); + + armnn::OptimizerOptions options; + options.m_ReduceFp32ToFp16 = params.m_EnableFp16TurboMode; + + optNet = armnn::Optimize(*network, params.m_ComputeDevice, m_Runtime->GetDeviceSpec(), options); + if (!optNet) + { + throw armnn::Exception("Optimize returned nullptr"); + } } if (params.m_VisualizePostOptimizationModel) @@ -157,16 +289,46 @@ public: void Run(const std::vector& input, std::vector& output) { BOOST_ASSERT(output.size() == GetOutputSize()); + + std::shared_ptr profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); + if (profiler) + { + profiler->EnableProfiling(m_EnableProfiling); + } + armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier, - MakeInputTensors(input), - MakeOutputTensors(output)); + MakeInputTensors(input), + MakeOutputTensors(output)); if (ret == armnn::Status::Failure) { throw armnn::Exception("IRuntime::EnqueueWorkload failed"); } } + const InferenceModelInternal::BindingPointInfo & GetInputBindingInfo() const + { + return m_InputBindingInfo; + } + + const InferenceModelInternal::BindingPointInfo & GetOutputBindingInfo() const + { + return m_OutputBindingInfo; + } + + InferenceModelInternal::QuantizationParams GetQuantizationParams() const + { + return std::make_pair(m_OutputBindingInfo.second.GetQuantizationScale(), + m_OutputBindingInfo.second.GetQuantizationOffset()); + } + private: + armnn::NetworkId m_NetworkIdentifier; + std::shared_ptr m_Runtime; + + InferenceModelInternal::BindingPointInfo m_InputBindingInfo; + InferenceModelInternal::BindingPointInfo m_OutputBindingInfo; + bool m_EnableProfiling; + template armnn::InputTensors MakeInputTensors(const TContainer& inputTensorData) { @@ -178,10 +340,4 @@ private: { return ::MakeOutputTensors(m_OutputBindingInfo, outputTensorData); } - - armnn::NetworkId m_NetworkIdentifier; - armnn::IRuntimePtr m_Runtime; - - std::pair m_InputBindingInfo; - std::pair m_OutputBindingInfo; }; diff --git a/tests/InferenceTest.cpp b/tests/InferenceTest.cpp index 161481f2cd..477ae4e84e 100644 --- a/tests/InferenceTest.cpp +++ b/tests/InferenceTest.cpp @@ -4,6 +4,7 @@ // #include "InferenceTest.hpp" +#include "../src/armnn/Profiling.hpp" #include #include #include @@ -26,7 +27,6 @@ namespace armnn { namespace test { - /// Parse the command line of an ArmNN (or referencetests) inference test program. /// \return false if any error occurred during options processing, otherwise true bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, @@ -40,15 +40,17 @@ bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCas try { - // Add generic options needed for all inference tests + // Adds generic options needed for all inference tests. desc.add_options() ("help", "Display help messages") ("iterations,i", po::value(&outParams.m_IterationCount)->default_value(0), "Sets the number number of inferences to perform. If unset, a default number will be ran.") ("inference-times-file", po::value(&outParams.m_InferenceTimesFile)->default_value(""), - "If non-empty, each individual inference time will be recorded and output to this file"); + "If non-empty, each individual inference time will be recorded and output to this file") + ("event-based-profiling,e", po::value(&outParams.m_EnableProfiling)->default_value(0), + "Enables built in profiler. If unset, defaults to off."); - // Add options specific to the ITestCaseProvider + // Adds options specific to the ITestCaseProvider. testCaseProvider.AddCommandLineOptions(desc); } catch (const std::exception& e) @@ -111,7 +113,7 @@ bool InferenceTest(const InferenceTestOptions& params, IInferenceTestCaseProvider& testCaseProvider) { #if !defined (NDEBUG) - if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn + if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn. { BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate."; } @@ -121,7 +123,7 @@ bool InferenceTest(const InferenceTestOptions& params, unsigned int nbProcessed = 0; bool success = true; - // Open the file to write inference times to, if needed + // Opens the file to write inference times too, if needed. ofstream inferenceTimesFile; const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty(); if (recordInferenceTimes) @@ -135,6 +137,13 @@ bool InferenceTest(const InferenceTestOptions& params, } } + // Create a profiler and register it for the current thread. + std::unique_ptr profiler = std::make_unique(); + ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); + + // Enable profiling if requested. + profiler->EnableProfiling(params.m_EnableProfiling); + // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer std::unique_ptr warmupTestCase = testCaseProvider.GetTestCase(0); if (warmupTestCase == nullptr) @@ -184,7 +193,7 @@ bool InferenceTest(const InferenceTestOptions& params, double timeTakenS = duration(predictEnd - predictStart).count(); totalTime += timeTakenS; - // Output inference times if needed + // Outputss inference times, if needed. if (recordInferenceTimes) { inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl; diff --git a/tests/InferenceTest.hpp b/tests/InferenceTest.hpp index 5f53c06a88..181afe4d8f 100644 --- a/tests/InferenceTest.hpp +++ b/tests/InferenceTest.hpp @@ -6,11 +6,14 @@ #include "armnn/ArmNN.hpp" #include "armnn/TypesUtils.hpp" +#include "InferenceModel.hpp" + #include #include #include + namespace armnn { @@ -40,9 +43,11 @@ struct InferenceTestOptions { unsigned int m_IterationCount; std::string m_InferenceTimesFile; + bool m_EnableProfiling; InferenceTestOptions() - : m_IterationCount(0) + : m_IterationCount(0), + m_EnableProfiling(0) {} }; @@ -108,6 +113,31 @@ private: std::vector m_Output; }; +template +struct ToFloat { }; // nothing defined for the generic case + +template <> +struct ToFloat +{ + static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &) + { + // assuming that float models are not quantized + return value; + } +}; + +template <> +struct ToFloat +{ + static inline float Convert(uint8_t value, + const InferenceModelInternal::QuantizationParams & quantizationParams) + { + return armnn::Dequantize(value, + quantizationParams.first, + quantizationParams.second); + } +}; + template class ClassifierTestCase : public InferenceModelTestCase { @@ -125,6 +155,8 @@ public: private: unsigned int m_Label; + InferenceModelInternal::QuantizationParams m_QuantizationParams; + /// These fields reference the corresponding member in the ClassifierTestCaseProvider. /// @{ int& m_NumInferencesRef; @@ -154,17 +186,17 @@ private: std::unique_ptr m_Model; std::string m_DataDir; - std::function m_ConstructDatabase; + std::function m_ConstructDatabase; std::unique_ptr m_Database; - int m_NumInferences; // Referenced by test cases - int m_NumCorrectInferences; // Referenced by test cases + int m_NumInferences; // Referenced by test cases. + int m_NumCorrectInferences; // Referenced by test cases. std::string m_ValidationFileIn; - std::vector m_ValidationPredictions; // Referenced by test cases + std::vector m_ValidationPredictions; // Referenced by test cases. std::string m_ValidationFileOut; - std::vector m_ValidationPredictionsOut; // Referenced by test cases + std::vector m_ValidationPredictionsOut; // Referenced by test cases. }; bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, diff --git a/tests/InferenceTest.inl b/tests/InferenceTest.inl index a36e231e76..16df7bace3 100644 --- a/tests/InferenceTest.inl +++ b/tests/InferenceTest.inl @@ -4,8 +4,6 @@ // #include "InferenceTest.hpp" -#include "InferenceModel.hpp" - #include #include #include @@ -30,6 +28,7 @@ namespace armnn namespace test { + template ClassifierTestCase::ClassifierTestCase( int& numInferencesRef, @@ -42,6 +41,7 @@ ClassifierTestCase::ClassifierTestCase( std::vector modelInput) : InferenceModelTestCase(model, testCaseId, std::move(modelInput), model.GetOutputSize()) , m_Label(label) + , m_QuantizationParams(model.GetQuantizationParams()) , m_NumInferencesRef(numInferencesRef) , m_NumCorrectInferencesRef(numCorrectInferencesRef) , m_ValidationPredictions(validationPredictions) @@ -60,7 +60,7 @@ TestCaseResult ClassifierTestCase::ProcessResult(cons int index = 0; for (const auto & o : output) { - resultMap[o] = index++; + resultMap[ToFloat::Convert(o, m_QuantizationParams)] = index++; } } @@ -78,7 +78,7 @@ TestCaseResult ClassifierTestCase::ProcessResult(cons const unsigned int prediction = boost::numeric_cast( std::distance(output.begin(), std::max_element(output.begin(), output.end()))); - // If we're just running the defaultTestCaseIds, each one must be classified correctly + // If we're just running the defaultTestCaseIds, each one must be classified correctly. if (params.m_IterationCount == 0 && prediction != m_Label) { BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" << @@ -86,7 +86,7 @@ TestCaseResult ClassifierTestCase::ProcessResult(cons return TestCaseResult::Failed; } - // If a validation file was provided as input, check that the prediction matches + // If a validation file was provided as input, it checks that the prediction matches. if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId]) { BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" << @@ -94,13 +94,13 @@ TestCaseResult ClassifierTestCase::ProcessResult(cons return TestCaseResult::Failed; } - // If a validation file was requested as output, store the predictions + // If a validation file was requested as output, it stores the predictions. if (m_ValidationPredictionsOut) { m_ValidationPredictionsOut->push_back(prediction); } - // Update accuracy stats + // Updates accuracy stats. m_NumInferencesRef++; if (prediction == m_Label) { @@ -154,7 +154,7 @@ bool ClassifierTestCaseProvider::ProcessCommandLineOp return false; } - m_Database = std::make_unique(m_ConstructDatabase(m_DataDir.c_str())); + m_Database = std::make_unique(m_ConstructDatabase(m_DataDir.c_str(), *m_Model)); if (!m_Database) { return false; @@ -191,7 +191,7 @@ bool ClassifierTestCaseProvider::OnInferenceTestFinis boost::numeric_cast(m_NumInferences); BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy; - // If a validation file was requested as output, save the predictions to it + // If a validation file was requested as output, the predictions are saved to it. if (!m_ValidationFileOut.empty()) { std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out); @@ -215,7 +215,7 @@ bool ClassifierTestCaseProvider::OnInferenceTestFinis template void ClassifierTestCaseProvider::ReadPredictions() { - // Read expected predictions from the input validation file (if provided) + // Reads the expected predictions from the input validation file (if provided). if (!m_ValidationFileIn.empty()) { std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in); @@ -242,7 +242,7 @@ int InferenceTestMain(int argc, const std::vector& defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider) { - // Configure logging for both the ARMNN library and this test program + // Configures logging for both the ARMNN library and this test program. #ifdef NDEBUG armnn::LogSeverity level = armnn::LogSeverity::Info; #else @@ -275,20 +275,35 @@ int InferenceTestMain(int argc, } } +// +// This function allows us to create a classifier inference test based on: +// - a model file name +// - which can be a binary or a text file for protobuf formats +// - an input tensor name +// - an output tensor name +// - a set of test case ids +// - a callback method which creates an object that can return images +// called 'Database' in these tests +// - and an input tensor shape +// template -int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary, - const char* inputBindingName, const char* outputBindingName, - const std::vector& defaultTestCaseIds, - TConstructDatabaseCallable constructDatabase, - const armnn::TensorShape* inputTensorShape) + typename TParser, + typename TConstructDatabaseCallable> +int ClassifierInferenceTestMain(int argc, + char* argv[], + const char* modelFilename, + bool isModelBinary, + const char* inputBindingName, + const char* outputBindingName, + const std::vector& defaultTestCaseIds, + TConstructDatabaseCallable constructDatabase, + const armnn::TensorShape* inputTensorShape) { return InferenceTestMain(argc, argv, defaultTestCaseIds, [=] () { - using InferenceModel = InferenceModel; + using InferenceModel = InferenceModel; using TestCaseProvider = ClassifierTestCaseProvider; return make_unique(constructDatabase, @@ -308,6 +323,7 @@ int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilenam modelParams.m_IsModelBinary = isModelBinary; modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice; modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel; + modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode; return std::make_unique(modelParams); }); diff --git a/tests/InferenceTestImage.cpp b/tests/InferenceTestImage.cpp index 205460a2f2..cc85adcf3f 100644 --- a/tests/InferenceTestImage.cpp +++ b/tests/InferenceTestImage.cpp @@ -37,6 +37,90 @@ unsigned int GetImageChannelIndex(ImageChannelLayout channelLayout, ImageChannel } } +inline float Lerp(float a, float b, float w) +{ + return w * b + (1.f - w) * a; +} + +inline void PutData(std::vector & data, + const unsigned int width, + const unsigned int x, + const unsigned int y, + const unsigned int c, + float value) +{ + data[(3*((y*width)+x)) + c] = value; +} + +std::vector ResizeBilinearAndNormalize(const InferenceTestImage & image, + const unsigned int outputWidth, + const unsigned int outputHeight, + const std::array& mean, + const std::array& stddev) +{ + std::vector out; + out.resize(outputWidth * outputHeight * 3); + + // We follow the definition of TensorFlow and AndroidNN: the top-left corner of a texel in the output + // image is projected into the input image to figure out the interpolants and weights. Note that this + // will yield different results than if projecting the centre of output texels. + + const unsigned int inputWidth = image.GetWidth(); + const unsigned int inputHeight = image.GetHeight(); + + // How much to scale pixel coordinates in the output image to get the corresponding pixel coordinates + // in the input image. + const float scaleY = boost::numeric_cast(inputHeight) / boost::numeric_cast(outputHeight); + const float scaleX = boost::numeric_cast(inputWidth) / boost::numeric_cast(outputWidth); + + uint8_t rgb_x0y0[3]; + uint8_t rgb_x1y0[3]; + uint8_t rgb_x0y1[3]; + uint8_t rgb_x1y1[3]; + + for (unsigned int y = 0; y < outputHeight; ++y) + { + // Corresponding real-valued height coordinate in input image. + const float iy = boost::numeric_cast(y) * scaleY; + + // Discrete height coordinate of top-left texel (in the 2x2 texel area used for interpolation). + const float fiy = floorf(iy); + const unsigned int y0 = boost::numeric_cast(fiy); + + // Interpolation weight (range [0,1]) + const float yw = iy - fiy; + + for (unsigned int x = 0; x < outputWidth; ++x) + { + // Real-valued and discrete width coordinates in input image. + const float ix = boost::numeric_cast(x) * scaleX; + const float fix = floorf(ix); + const unsigned int x0 = boost::numeric_cast(fix); + + // Interpolation weight (range [0,1]). + const float xw = ix - fix; + + // Discrete width/height coordinates of texels below and to the right of (x0, y0). + const unsigned int x1 = std::min(x0 + 1, inputWidth - 1u); + const unsigned int y1 = std::min(y0 + 1, inputHeight - 1u); + + std::tie(rgb_x0y0[0], rgb_x0y0[1], rgb_x0y0[2]) = image.GetPixelAs3Channels(x0, y0); + std::tie(rgb_x1y0[0], rgb_x1y0[1], rgb_x1y0[2]) = image.GetPixelAs3Channels(x1, y0); + std::tie(rgb_x0y1[0], rgb_x0y1[1], rgb_x0y1[2]) = image.GetPixelAs3Channels(x0, y1); + std::tie(rgb_x1y1[0], rgb_x1y1[1], rgb_x1y1[2]) = image.GetPixelAs3Channels(x1, y1); + + for (unsigned c=0; c<3; ++c) + { + const float ly0 = Lerp(float(rgb_x0y0[c]), float(rgb_x1y0[c]), xw); + const float ly1 = Lerp(float(rgb_x0y1[c]), float(rgb_x1y1[c]), xw); + const float l = Lerp(ly0, ly1, yw); + PutData(out, outputWidth, x, y, c, ((l/255.0f) - mean[c])/stddev[c]); + } + } + } + return out; +} + } // namespace InferenceTestImage::InferenceTestImage(char const* filePath) @@ -94,42 +178,70 @@ std::tuple InferenceTestImage::GetPixelAs3Channels(un return std::make_tuple(outPixelData[0], outPixelData[1], outPixelData[2]); } -void InferenceTestImage::Resize(unsigned int newWidth, unsigned int newHeight) -{ - if (newWidth == 0 || newHeight == 0) - { - throw InferenceTestImageResizeFailed(boost::str(boost::format("None of the dimensions passed to a resize " - "operation can be zero. Requested width: %1%. Requested height: %2%.") % newWidth % newHeight)); - } - - if (newWidth == m_Width && newHeight == m_Height) - { - // nothing to do - return; - } +void InferenceTestImage::StbResize(InferenceTestImage& im, const unsigned int newWidth, const unsigned int newHeight) +{ std::vector newData; - newData.resize(newWidth * newHeight * GetNumChannels() * GetSingleElementSizeInBytes()); + newData.resize(newWidth * newHeight * im.GetNumChannels() * im.GetSingleElementSizeInBytes()); // boost::numeric_cast<>() is used for user-provided data (protecting about overflows). - // static_cast<> ok for internal data (assumes that, when internal data was originally provided by a user, + // static_cast<> is ok for internal data (assumes that, when internal data was originally provided by a user, // a boost::numeric_cast<>() handled the conversion). const int nW = boost::numeric_cast(newWidth); const int nH = boost::numeric_cast(newHeight); - const int w = static_cast(GetWidth()); - const int h = static_cast(GetHeight()); - const int numChannels = static_cast(GetNumChannels()); + const int w = static_cast(im.GetWidth()); + const int h = static_cast(im.GetHeight()); + const int numChannels = static_cast(im.GetNumChannels()); - const int res = stbir_resize_uint8(m_Data.data(), w, h, 0, newData.data(), nW, nH, 0, numChannels); + const int res = stbir_resize_uint8(im.m_Data.data(), w, h, 0, newData.data(), nW, nH, 0, numChannels); if (res == 0) { throw InferenceTestImageResizeFailed("The resizing operation failed"); } - m_Data.swap(newData); - m_Width = newWidth; - m_Height = newHeight; + im.m_Data.swap(newData); + im.m_Width = newWidth; + im.m_Height = newHeight; +} + +std::vector InferenceTestImage::Resize(unsigned int newWidth, + unsigned int newHeight, + const armnn::CheckLocation& location, + const ResizingMethods meth, + const std::array& mean, + const std::array& stddev) +{ + std::vector out; + if (newWidth == 0 || newHeight == 0) + { + throw InferenceTestImageResizeFailed(boost::str(boost::format("None of the dimensions passed to a resize " + "operation can be zero. Requested width: %1%. Requested height: %2%.") % newWidth % newHeight)); + } + + if (newWidth == m_Width && newHeight == m_Height) + { + // Nothing to do. + return out; + } + + switch (meth) { + case ResizingMethods::STB: + { + StbResize(*this, newWidth, newHeight); + break; + } + case ResizingMethods::BilinearAndNormalized: + { + out = ResizeBilinearAndNormalize(*this, newWidth, newHeight, mean, stddev); + break; + } + default: + throw InferenceTestImageResizeFailed(boost::str( + boost::format("Unknown resizing method asked ArmNN only supports {STB, BilinearAndNormalized} %1%") + % location.AsString())); + } + return out; } void InferenceTestImage::Write(WriteFormat format, const char* filePath) const @@ -252,4 +364,4 @@ std::vector GetImageDataAsNormalizedFloats(ImageChannelLayout layout, } return imageData; -} \ No newline at end of file +} diff --git a/tests/InferenceTestImage.hpp b/tests/InferenceTestImage.hpp index 34403c0dda..657ea04c7b 100644 --- a/tests/InferenceTestImage.hpp +++ b/tests/InferenceTestImage.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include @@ -57,6 +58,13 @@ public: Tga }; + // Common names used to identify a channel in a pixel. + enum class ResizingMethods + { + STB, + BilinearAndNormalized, + }; + explicit InferenceTestImage(const char* filePath); InferenceTestImage(InferenceTestImage&&) = delete; @@ -76,7 +84,16 @@ public: // of the tuple corresponds to the Red channel, whereas the last element is the Blue channel). std::tuple GetPixelAs3Channels(unsigned int x, unsigned int y) const; - void Resize(unsigned int newWidth, unsigned int newHeight); + void StbResize(InferenceTestImage& im, const unsigned int newWidth, const unsigned int newHeight); + + + std::vector Resize(unsigned int newWidth, + unsigned int newHeight, + const armnn::CheckLocation& location, + const ResizingMethods meth = ResizingMethods::STB, + const std::array& mean = {{0.0, 0.0, 0.0}}, + const std::array& stddev = {{1.0, 1.0, 1.0}}); + void Write(WriteFormat format, const char* filePath) const; private: @@ -91,7 +108,7 @@ private: unsigned int m_NumChannels; }; -// Common names used to identify a channel in a pixel +// Common names used to identify a channel in a pixel. enum class ImageChannel { R, @@ -99,7 +116,7 @@ enum class ImageChannel B }; -// Channel layouts handled by the test framework +// Channel layouts handled by the test framework. enum class ImageChannelLayout { Rgb, @@ -112,7 +129,7 @@ enum class ImageChannelLayout std::vector GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout layout, const InferenceTestImage& image); -// Reads the contents of an inference test image as 3-channel pixels whose value is the result of subtracting the mean +// Reads the contents of an inference test image as 3-channel pixels, whose value is the result of subtracting the mean // from the values in the original image. Channel data is stored according to the ArmNN layout (CHW). The order in // which channels appear in the resulting vector is defined by the provided layout. The order of the channels of the // provided mean should also match the given layout. diff --git a/tests/MnistDatabase.cpp b/tests/MnistDatabase.cpp index 5c10b0c2b4..2ca39ef6de 100644 --- a/tests/MnistDatabase.cpp +++ b/tests/MnistDatabase.cpp @@ -47,7 +47,7 @@ std::unique_ptr MnistDatabase::GetTestCaseData(uns unsigned int magic, num, row, col; - // check the files have the correct header + // Checks the files have the correct header. imageStream.read(reinterpret_cast(&magic), sizeof(magic)); if (magic != 0x03080000) { @@ -61,8 +61,8 @@ std::unique_ptr MnistDatabase::GetTestCaseData(uns return nullptr; } - // Endian swap image and label file - All the integers in the files are stored in MSB first(high endian) format, - // hence need to flip the bytes of the header if using it on Intel processors or low-endian machines + // Endian swaps the image and label file - all the integers in the files are stored in MSB first(high endian) + // format, hence it needs to flip the bytes of the header if using it on Intel processors or low-endian machines labelStream.read(reinterpret_cast(&num), sizeof(num)); imageStream.read(reinterpret_cast(&num), sizeof(num)); EndianSwap(num); @@ -71,7 +71,7 @@ std::unique_ptr MnistDatabase::GetTestCaseData(uns imageStream.read(reinterpret_cast(&col), sizeof(col)); EndianSwap(col); - // read image and label into memory + // Reads image and label into memory. imageStream.seekg(testCaseId * g_kMnistImageByteSize, std::ios_base::cur); imageStream.read(reinterpret_cast(&I[0]), g_kMnistImageByteSize); labelStream.seekg(testCaseId, std::ios_base::cur); diff --git a/tests/MnistDatabase.hpp b/tests/MnistDatabase.hpp index 281b708589..b1336bcef8 100644 --- a/tests/MnistDatabase.hpp +++ b/tests/MnistDatabase.hpp @@ -12,7 +12,8 @@ class MnistDatabase { public: - using TTestCaseData = ClassifierTestCaseData; + using DataType = float; + using TTestCaseData = ClassifierTestCaseData; explicit MnistDatabase(const std::string& binaryFileDirectory, bool scaleValues = false); std::unique_ptr GetTestCaseData(unsigned int testCaseId); diff --git a/tests/MobileNetDatabase.cpp b/tests/MobileNetDatabase.cpp deleted file mode 100644 index 66f297c502..0000000000 --- a/tests/MobileNetDatabase.cpp +++ /dev/null @@ -1,133 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// See LICENSE file in the project root for full license information. -// -#include "InferenceTestImage.hpp" -#include "MobileNetDatabase.hpp" - -#include -#include -#include - -#include -#include -#include - -namespace -{ - -inline float Lerp(float a, float b, float w) -{ - return w * b + (1.f - w) * a; -} - -inline void PutData(std::vector & data, - const unsigned int width, - const unsigned int x, - const unsigned int y, - const unsigned int c, - float value) -{ - data[(3*((y*width)+x)) + c] = value; -} - -std::vector -ResizeBilinearAndNormalize(const InferenceTestImage & image, - const unsigned int outputWidth, - const unsigned int outputHeight) -{ - std::vector out; - out.resize(outputWidth * outputHeight * 3); - - // We follow the definition of TensorFlow and AndroidNN: The top-left corner of a texel in the output - // image is projected into the input image to figure out the interpolants and weights. Note that this - // will yield different results than if projecting the centre of output texels. - - const unsigned int inputWidth = image.GetWidth(); - const unsigned int inputHeight = image.GetHeight(); - - // How much to scale pixel coordinates in the output image to get the corresponding pixel coordinates - // in the input image - const float scaleY = boost::numeric_cast(inputHeight) / boost::numeric_cast(outputHeight); - const float scaleX = boost::numeric_cast(inputWidth) / boost::numeric_cast(outputWidth); - - uint8_t rgb_x0y0[3]; - uint8_t rgb_x1y0[3]; - uint8_t rgb_x0y1[3]; - uint8_t rgb_x1y1[3]; - - for (unsigned int y = 0; y < outputHeight; ++y) - { - // Corresponding real-valued height coordinate in input image - const float iy = boost::numeric_cast(y) * scaleY; - - // Discrete height coordinate of top-left texel (in the 2x2 texel area used for interpolation) - const float fiy = floorf(iy); - const unsigned int y0 = boost::numeric_cast(fiy); - - // Interpolation weight (range [0,1]) - const float yw = iy - fiy; - - for (unsigned int x = 0; x < outputWidth; ++x) - { - // Real-valued and discrete width coordinates in input image - const float ix = boost::numeric_cast(x) * scaleX; - const float fix = floorf(ix); - const unsigned int x0 = boost::numeric_cast(fix); - - // Interpolation weight (range [0,1]) - const float xw = ix - fix; - - // Discrete width/height coordinates of texels below and to the right of (x0, y0) - const unsigned int x1 = std::min(x0 + 1, inputWidth - 1u); - const unsigned int y1 = std::min(y0 + 1, inputHeight - 1u); - - std::tie(rgb_x0y0[0], rgb_x0y0[1], rgb_x0y0[2]) = image.GetPixelAs3Channels(x0, y0); - std::tie(rgb_x1y0[0], rgb_x1y0[1], rgb_x1y0[2]) = image.GetPixelAs3Channels(x1, y0); - std::tie(rgb_x0y1[0], rgb_x0y1[1], rgb_x0y1[2]) = image.GetPixelAs3Channels(x0, y1); - std::tie(rgb_x1y1[0], rgb_x1y1[1], rgb_x1y1[2]) = image.GetPixelAs3Channels(x1, y1); - - for (unsigned c=0; c<3; ++c) - { - const float ly0 = Lerp(float(rgb_x0y0[c]), float(rgb_x1y0[c]), xw); - const float ly1 = Lerp(float(rgb_x0y1[c]), float(rgb_x1y1[c]), xw); - const float l = Lerp(ly0, ly1, yw); - PutData(out, outputWidth, x, y, c, l/255.0f); - } - } - } - - return out; -} - -} // end of anonymous namespace - - -MobileNetDatabase::MobileNetDatabase(const std::string& binaryFileDirectory, - unsigned int width, - unsigned int height, - const std::vector& imageSet) -: m_BinaryDirectory(binaryFileDirectory) -, m_Height(height) -, m_Width(width) -, m_ImageSet(imageSet) -{ -} - -std::unique_ptr -MobileNetDatabase::GetTestCaseData(unsigned int testCaseId) -{ - testCaseId = testCaseId % boost::numeric_cast(m_ImageSet.size()); - const ImageSet& imageSet = m_ImageSet[testCaseId]; - const std::string fullPath = m_BinaryDirectory + imageSet.first; - - InferenceTestImage image(fullPath.c_str()); - - // this ResizeBilinear result is closer to the tensorflow one than STB. - // there is still some difference though, but the inference results are - // similar to tensorflow for MobileNet - std::vector resized(ResizeBilinearAndNormalize(image, m_Width, m_Height)); - - const unsigned int label = imageSet.second; - return std::make_unique(label, std::move(resized)); -} diff --git a/tests/MobileNetDatabase.hpp b/tests/MobileNetDatabase.hpp deleted file mode 100644 index eb34260e90..0000000000 --- a/tests/MobileNetDatabase.hpp +++ /dev/null @@ -1,36 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// See LICENSE file in the project root for full license information. -// -#pragma once - -#include "ClassifierTestCaseData.hpp" - -#include -#include -#include -#include - -using ImageSet = std::pair; - -class MobileNetDatabase -{ -public: - using TTestCaseData = ClassifierTestCaseData; - - explicit MobileNetDatabase(const std::string& binaryFileDirectory, - unsigned int width, - unsigned int height, - const std::vector& imageSet); - - std::unique_ptr GetTestCaseData(unsigned int testCaseId); - -private: - unsigned int GetNumImageElements() const { return 3 * m_Width * m_Height; } - unsigned int GetNumImageBytes() const { return 4 * GetNumImageElements(); } - - std::string m_BinaryDirectory; - unsigned int m_Height; - unsigned int m_Width; - const std::vector m_ImageSet; -}; \ No newline at end of file diff --git a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp index 37138f4a78..ca6ff45b1b 100644 --- a/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp +++ b/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp @@ -30,25 +30,26 @@ int main(int argc, char* argv[]) try { - // Configure logging for both the ARMNN library and this test program + // Configures logging for both the ARMNN library and this test program. armnn::ConfigureLogging(true, true, level); armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level); namespace po = boost::program_options; - armnn::Compute computeDevice; + std::vector computeDevice; std::string modelDir; std::string dataDir; po::options_description desc("Options"); try { - // Add generic options needed for all inference tests + // Adds generic options needed for all inference tests. desc.add_options() ("help", "Display help messages") ("model-dir,m", po::value(&modelDir)->required(), "Path to directory containing the Cifar10 model file") - ("compute,c", po::value(&computeDevice)->default_value(armnn::Compute::CpuAcc), + ("compute,c", po::value>(&computeDevice)->default_value + ({armnn::Compute::CpuAcc, armnn::Compute::CpuRef}), "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") ("data-dir,d", po::value(&dataDir)->required(), "Path to directory containing the Cifar10 test data"); @@ -91,9 +92,10 @@ int main(int argc, char* argv[]) string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel"; // Create runtime - armnn::IRuntimePtr runtime(armnn::IRuntime::Create(computeDevice)); + armnn::IRuntime::CreationOptions options; + armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options)); - // Load networks + // Loads networks. armnn::Status status; struct Net { @@ -116,14 +118,14 @@ int main(int argc, char* argv[]) const int networksCount = 4; for (int i = 0; i < networksCount; ++i) { - // Create a network from a file on disk + // Creates a network from a file on the disk. armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" }); - // optimize the network + // Optimizes the network. armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr); try { - optimizedNet = armnn::Optimize(*network, runtime->GetDeviceSpec()); + optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec()); } catch (armnn::Exception& e) { @@ -133,7 +135,7 @@ int main(int argc, char* argv[]) return 1; } - // Load the network into the runtime + // Loads the network into the runtime. armnn::NetworkId networkId; status = runtime->LoadNetwork(networkId, std::move(optimizedNet)); if (status == armnn::Status::Failure) @@ -147,7 +149,7 @@ int main(int argc, char* argv[]) parser->GetNetworkOutputBindingInfo("prob")); } - // Load a test case and test inference + // Loads a test case and tests inference. if (!ValidateDirectory(dataDir)) { return 1; @@ -156,10 +158,10 @@ int main(int argc, char* argv[]) for (unsigned int i = 0; i < 3; ++i) { - // Load test case data (including image data) + // Loads test case data (including image data). std::unique_ptr testCaseData = cifar10.GetTestCaseData(i); - // Test inference + // Tests inference. std::vector> outputs(networksCount); for (unsigned int k = 0; k < networksCount; ++k) @@ -174,7 +176,7 @@ int main(int argc, char* argv[]) } } - // Compare outputs + // Compares outputs. for (unsigned int k = 1; k < networksCount; ++k) { if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end())) diff --git a/tests/OnnxMnist-Armnn/OnnxMnist-Armnn.cpp b/tests/OnnxMnist-Armnn/OnnxMnist-Armnn.cpp new file mode 100644 index 0000000000..a372f54ddb --- /dev/null +++ b/tests/OnnxMnist-Armnn/OnnxMnist-Armnn.cpp @@ -0,0 +1,39 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include "../InferenceTest.hpp" +#include "../MnistDatabase.hpp" +#include "armnnOnnxParser/IOnnxParser.hpp" + +int main(int argc, char* argv[]) +{ + armnn::TensorShape inputTensorShape({ 1, 1, 28, 28 }); + + int retVal = EXIT_FAILURE; + try + { + using DataType = float; + using DatabaseType = MnistDatabase; + using ParserType = armnnOnnxParser::IOnnxParser; + using ModelType = InferenceModel; + + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. + retVal = armnn::test::ClassifierInferenceTestMain( + argc, argv, "mnist_onnx.onnx", true, + "Input3", "Plus214_Output_0", { 0, 1, 2, 3, 4}, + [](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir, true); + }, + &inputTensorShape); + } + catch (const std::exception& e) + { + // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an + // exception of type std::length_error. + // Using stderr instead in this context as there is no point in nesting try-catch blocks here. + std::cerr << "WARNING: OnnxMnist-Armnn: An error has occurred when running " + "the classifier inference tests: " << e.what() << std::endl; + } + return retVal; +} diff --git a/tests/OnnxMnist-Armnn/Validation.txt b/tests/OnnxMnist-Armnn/Validation.txt new file mode 100644 index 0000000000..8ddde9340a --- /dev/null +++ b/tests/OnnxMnist-Armnn/Validation.txt @@ -0,0 +1,1000 @@ +7 +2 +1 +0 +4 +1 +4 +9 +5 +9 +0 +6 +9 +0 +1 +5 +9 +7 +3 +4 +9 +6 +6 +5 +4 +0 +7 +4 +0 +1 +3 +1 +3 +4 +7 +2 +7 +1 +2 +1 +1 +7 +4 +2 +3 +5 +1 +2 +4 +4 +6 +3 +5 +5 +6 +0 +4 +1 +9 +5 +7 +8 +5 +3 +7 +4 +6 +4 +3 +0 +7 +0 +2 +9 +1 +7 +3 +2 +9 +7 +7 +6 +2 +7 +8 +4 +7 +3 +6 +1 +3 +6 +4 +3 +1 +4 +1 +7 +6 +9 +6 +0 +5 +4 +9 +9 +2 +1 +9 +4 +8 +7 +3 +9 +7 +4 +4 +4 +9 +2 +5 +4 +7 +6 +7 +9 +0 +5 +8 +5 +6 +6 +5 +7 +8 +1 +0 +1 +6 +4 +6 +7 +3 +1 +7 +1 +8 +2 +0 +2 +9 +9 +5 +5 +1 +5 +6 +0 +3 +4 +4 +6 +5 +4 +6 +5 +4 +5 +1 +4 +4 +7 +2 +3 +2 +7 +1 +8 +1 +8 +1 +8 +5 +0 +8 +9 +2 +5 +0 +1 +1 +1 +0 +9 +0 +3 +1 +6 +4 +2 +3 +6 +1 +1 +1 +3 +9 +5 +2 +9 +4 +5 +9 +3 +9 +0 +3 +6 +5 +5 +7 +2 +2 +7 +1 +2 +8 +4 +1 +7 +3 +3 +8 +8 +7 +9 +2 +2 +4 +1 +5 +9 +8 +7 +2 +3 +0 +2 +4 +2 +4 +1 +9 +5 +7 +7 +2 +8 +2 +6 +8 +5 +7 +7 +9 +1 +0 +1 +8 +0 +3 +0 +1 +9 +9 +4 +1 +8 +2 +1 +2 +9 +7 +5 +9 +2 +6 +4 +1 +5 +8 +2 +9 +2 +0 +4 +0 +0 +2 +8 +4 +7 +1 +2 +4 +0 +2 +7 +4 +3 +3 +0 +0 +3 +1 +9 +6 +5 +2 +5 +9 +2 +9 +3 +0 +4 +2 +0 +7 +1 +1 +2 +1 +5 +3 +3 +9 +7 +8 +6 +5 +6 +1 +3 +8 +1 +0 +5 +1 +3 +1 +5 +5 +6 +1 +8 +5 +1 +7 +9 +4 +6 +2 +2 +5 +0 +6 +5 +6 +3 +7 +2 +0 +8 +8 +5 +4 +1 +1 +4 +0 +3 +3 +7 +6 +1 +6 +2 +1 +9 +2 +8 +6 +1 +9 +5 +2 +5 +4 +4 +2 +8 +3 +8 +2 +4 +5 +0 +3 +1 +7 +7 +5 +7 +9 +7 +1 +9 +2 +1 +4 +2 +9 +2 +0 +4 +9 +1 +4 +8 +1 +8 +4 +5 +9 +8 +8 +3 +7 +6 +0 +0 +3 +0 +2 +0 +6 +4 +9 +5 +3 +3 +2 +3 +9 +1 +2 +6 +8 +0 +5 +6 +6 +6 +3 +8 +8 +2 +7 +5 +8 +9 +6 +1 +8 +4 +1 +2 +5 +9 +1 +9 +7 +5 +4 +0 +8 +9 +9 +1 +0 +5 +2 +3 +7 +0 +9 +4 +0 +6 +3 +9 +5 +2 +1 +3 +1 +3 +6 +5 +7 +4 +2 +2 +6 +3 +2 +6 +5 +4 +8 +9 +7 +1 +3 +0 +3 +8 +3 +1 +9 +3 +4 +4 +6 +4 +2 +1 +8 +2 +5 +4 +8 +8 +4 +0 +0 +2 +3 +2 +7 +7 +0 +8 +7 +4 +4 +7 +9 +6 +9 +0 +9 +8 +0 +4 +6 +0 +6 +3 +5 +4 +8 +3 +3 +9 +3 +3 +3 +7 +8 +0 +8 +2 +1 +7 +0 +6 +5 +4 +3 +8 +0 +9 +6 +3 +8 +0 +9 +9 +6 +8 +6 +8 +5 +7 +8 +6 +0 +2 +4 +0 +2 +2 +3 +1 +9 +7 +5 +8 +0 +8 +4 +6 +2 +6 +7 +9 +3 +2 +9 +8 +2 +2 +9 +2 +7 +3 +5 +9 +1 +8 +0 +2 +0 +5 +2 +1 +3 +7 +6 +7 +1 +2 +5 +8 +0 +3 +7 +1 +4 +0 +9 +1 +8 +6 +7 +7 +4 +3 +4 +9 +1 +9 +5 +1 +7 +3 +9 +7 +6 +9 +1 +3 +2 +8 +3 +3 +6 +7 +2 +8 +5 +8 +5 +1 +1 +4 +4 +3 +1 +0 +7 +7 +0 +7 +9 +4 +4 +8 +5 +5 +4 +0 +8 +2 +1 +0 +8 +4 +5 +0 +4 +0 +6 +1 +5 +3 +2 +6 +7 +2 +6 +9 +3 +1 +4 +6 +2 +5 +9 +2 +0 +6 +2 +1 +7 +3 +4 +1 +0 +5 +4 +3 +1 +1 +7 +4 +9 +9 +4 +8 +4 +0 +2 +4 +5 +1 +1 +6 +4 +7 +1 +9 +4 +2 +4 +1 +5 +5 +3 +8 +3 +1 +4 +5 +6 +8 +9 +4 +1 +5 +3 +8 +0 +3 +2 +5 +1 +2 +8 +3 +4 +4 +0 +8 +8 +3 +3 +1 +7 +3 +5 +9 +6 +3 +2 +6 +1 +3 +6 +0 +7 +2 +1 +7 +1 +4 +2 +4 +2 +1 +7 +9 +6 +1 +1 +2 +4 +8 +1 +7 +7 +4 +7 +0 +7 +3 +1 +3 +1 +0 +7 +7 +0 +3 +5 +5 +2 +7 +6 +6 +9 +2 +8 +3 +5 +2 +2 +5 +6 +0 +8 +2 +9 +2 +8 +8 +8 +8 +7 +4 +7 +3 +0 +6 +6 +3 +2 +1 +3 +2 +2 +9 +3 +0 +0 +5 +7 +8 +1 +4 +4 +6 +0 +2 +9 +1 +4 +7 +4 +7 +3 +9 +8 +8 +4 +7 +1 +2 +1 +2 +2 +3 +2 +3 +2 +3 +9 +1 +7 +4 +0 +3 +5 +5 +8 +6 +3 +2 +6 +7 +6 +6 +3 +2 +7 +9 +1 +1 +7 +5 +6 +4 +9 +5 +1 +3 +3 +4 +7 +8 +9 +1 +1 +0 +9 +1 +4 +4 +5 +4 +0 +6 +2 +2 +3 +1 +5 +1 +2 +0 +3 +8 +1 +2 +6 +7 +1 +6 +2 +3 +9 +0 +1 +2 +2 +0 +8 +9 diff --git a/tests/OnnxMobileNet-Armnn/OnnxMobileNet-Armnn.cpp b/tests/OnnxMobileNet-Armnn/OnnxMobileNet-Armnn.cpp new file mode 100644 index 0000000000..0d2d937469 --- /dev/null +++ b/tests/OnnxMobileNet-Armnn/OnnxMobileNet-Armnn.cpp @@ -0,0 +1,60 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include "../InferenceTest.hpp" +#include "../ImagePreprocessor.hpp" +#include "armnnOnnxParser/IOnnxParser.hpp" + +int main(int argc, char* argv[]) +{ + int retVal = EXIT_FAILURE; + try + { + // Coverity fix: The following code may throw an exception of type std::length_error. + std::vector imageSet = + { + {"Dog.jpg", 208}, + {"Cat.jpg", 281}, + {"shark.jpg", 2}, + }; + + armnn::TensorShape inputTensorShape({ 1, 3, 224, 224 }); + + using DataType = float; + using DatabaseType = ImagePreprocessor; + using ParserType = armnnOnnxParser::IOnnxParser; + using ModelType = InferenceModel; + + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. + retVal = armnn::test::ClassifierInferenceTestMain( + argc, argv, + "mobilenetv2-1.0.onnx", // model name + true, // model is binary + "data", "mobilenetv20_output_flatten0_reshape0", // input and output tensor names + { 0, 1, 2 }, // test images to test with as above + [&imageSet](const char* dataDir, const ModelType&) { + // This creates create a 1, 3, 224, 224 normalized input with mean and stddev to pass to Armnn + return DatabaseType( + dataDir, + 224, + 224, + imageSet, + 1.0, // scale + 0, // offset + {{0.485f, 0.456f, 0.406f}}, // mean + {{0.229f, 0.224f, 0.225f}}, // stddev + DatabaseType::DataFormat::NCHW); // format + }, + &inputTensorShape); + } + catch (const std::exception& e) + { + // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an + // exception of type std::length_error. + // Using stderr instead in this context as there is no point in nesting try-catch blocks here. + std::cerr << "WARNING: OnnxMobileNet-Armnn: An error has occurred when running " + "the classifier inference tests: " << e.what() << std::endl; + } + return retVal; +} diff --git a/tests/OnnxMobileNet-Armnn/Validation.txt b/tests/OnnxMobileNet-Armnn/Validation.txt new file mode 100644 index 0000000000..ccadd10253 --- /dev/null +++ b/tests/OnnxMobileNet-Armnn/Validation.txt @@ -0,0 +1,201 @@ +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 +208 +281 +2 \ No newline at end of file diff --git a/tests/OnnxMobileNet-Armnn/labels.txt b/tests/OnnxMobileNet-Armnn/labels.txt new file mode 100644 index 0000000000..d74ff557dd --- /dev/null +++ b/tests/OnnxMobileNet-Armnn/labels.txt @@ -0,0 +1,1001 @@ +0:background +1:tench, Tinca tinca +2:goldfish, Carassius auratus +3:great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +4:tiger shark, Galeocerdo cuvieri +5:hammerhead, hammerhead shark +6:electric ray, crampfish, numbfish, torpedo +7:stingray +8:cock +9:hen +10:ostrich, Struthio camelus +11:brambling, Fringilla montifringilla +12:goldfinch, Carduelis carduelis +13:house finch, linnet, Carpodacus mexicanus +14:junco, snowbird +15:indigo bunting, indigo finch, indigo bird, Passerina cyanea +16:robin, American robin, Turdus migratorius +17:bulbul +18:jay +19:magpie +20:chickadee +21:water ouzel, dipper +22:kite +23:bald eagle, American eagle, Haliaeetus leucocephalus +24:vulture +25:great grey owl, great gray owl, Strix nebulosa +26:European fire salamander, Salamandra salamandra +27:common newt, Triturus vulgaris +28:eft +29:spotted salamander, Ambystoma maculatum +30:axolotl, mud puppy, Ambystoma mexicanum +31:bullfrog, Rana catesbeiana +32:tree frog, tree-frog +33:tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui +34:loggerhead, loggerhead turtle, Caretta caretta +35:leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea +36:mud turtle +37:terrapin +38:box turtle, box tortoise +39:banded gecko +40:common iguana, iguana, Iguana iguana +41:American chameleon, anole, Anolis carolinensis +42:whiptail, whiptail lizard +43:agama +44:frilled lizard, Chlamydosaurus kingi +45:alligator lizard +46:Gila monster, Heloderma suspectum +47:green lizard, Lacerta viridis +48:African chameleon, Chamaeleo chamaeleon +49:Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis +50:African crocodile, Nile crocodile, Crocodylus niloticus +51:American alligator, Alligator mississipiensis +52:triceratops +53:thunder snake, worm snake, Carphophis amoenus +54:ringneck snake, ring-necked snake, ring snake +55:hognose snake, puff adder, sand viper +56:green snake, grass snake +57:king snake, kingsnake +58:garter snake, grass snake +59:water snake +60:vine snake +61:night snake, Hypsiglena torquata +62:boa constrictor, Constrictor constrictor +63:rock python, rock snake, Python sebae +64:Indian cobra, Naja naja +65:green mamba +66:sea snake +67:horned viper, cerastes, sand viper, horned asp, Cerastes cornutus +68:diamondback, diamondback rattlesnake, Crotalus adamanteus +69:sidewinder, horned rattlesnake, Crotalus cerastes +70:trilobite +71:harvestman, daddy longlegs, Phalangium opilio +72:scorpion +73:black and gold garden spider, Argiope aurantia +74:barn spider, Araneus cavaticus +75:garden spider, Aranea diademata +76:black widow, Latrodectus mactans +77:tarantula +78:wolf spider, hunting spider +79:tick +80:centipede +81:black grouse +82:ptarmigan +83:ruffed grouse, partridge, Bonasa umbellus +84:prairie chicken, prairie grouse, prairie fowl +85:peacock +86:quail +87:partridge +88:African grey, African gray, Psittacus erithacus +89:macaw +90:sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita +91:lorikeet +92:coucal +93:bee eater +94:hornbill +95:hummingbird +96:jacamar +97:toucan +98:drake +99:red-breasted merganser, Mergus serrator +100:goose +101:black swan, Cygnus atratus +102:tusker +103:echidna, spiny anteater, anteater +104:platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus +105:wallaby, brush kangaroo +106:koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus +107:wombat +108:jellyfish +109:sea anemone, anemone +110:brain coral +111:flatworm, platyhelminth +112:nematode, nematode worm, roundworm +113:conch +114:snail +115:slug +116:sea slug, nudibranch +117:chiton, coat-of-mail shell, sea cradle, polyplacophore +118:chambered nautilus, pearly nautilus, nautilus +119:Dungeness crab, Cancer magister +120:rock crab, Cancer irroratus +121:fiddler crab +122:king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica +123:American lobster, Northern lobster, Maine lobster, Homarus americanus +124:spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish +125:crayfish, crawfish, crawdad, crawdaddy +126:hermit crab +127:isopod +128:white stork, Ciconia ciconia +129:black stork, Ciconia nigra +130:spoonbill +131:flamingo +132:little blue heron, Egretta caerulea +133:American egret, great white heron, Egretta albus +134:bittern +135:crane +136:limpkin, Aramus pictus +137:European gallinule, Porphyrio porphyrio +138:American coot, marsh hen, mud hen, water hen, Fulica americana +139:bustard +140:ruddy turnstone, Arenaria interpres +141:red-backed sandpiper, dunlin, Erolia alpina +142:redshank, Tringa totanus +143:dowitcher +144:oystercatcher, oyster catcher +145:pelican +146:king penguin, Aptenodytes patagonica +147:albatross, mollymawk +148:grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus +149:killer whale, killer, orca, grampus, sea wolf, Orcinus orca +150:dugong, Dugong dugon +151:sea lion +152:Chihuahua +153:Japanese spaniel +154:Maltese dog, Maltese terrier, Maltese +155:Pekinese, Pekingese, Peke +156:Shih-Tzu +157:Blenheim spaniel +158:papillon +159:toy terrier +160:Rhodesian ridgeback +161:Afghan hound, Afghan +162:basset, basset hound +163:beagle +164:bloodhound, sleuthhound +165:bluetick +166:black-and-tan coonhound +167:Walker hound, Walker foxhound +168:English foxhound +169:redbone +170:borzoi, Russian wolfhound +171:Irish wolfhound +172:Italian greyhound +173:whippet +174:Ibizan hound, Ibizan Podenco +175:Norwegian elkhound, elkhound +176:otterhound, otter hound +177:Saluki, gazelle hound +178:Scottish deerhound, deerhound +179:Weimaraner +180:Staffordshire bullterrier, Staffordshire bull terrier +181:American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier +182:Bedlington terrier +183:Border terrier +184:Kerry blue terrier +185:Irish terrier +186:Norfolk terrier +187:Norwich terrier +188:Yorkshire terrier +189:wire-haired fox terrier +190:Lakeland terrier +191:Sealyham terrier, Sealyham +192:Airedale, Airedale terrier +193:cairn, cairn terrier +194:Australian terrier +195:Dandie Dinmont, Dandie Dinmont terrier +196:Boston bull, Boston terrier +197:miniature schnauzer +198:giant schnauzer +199:standard schnauzer +200:Scotch terrier, Scottish terrier, Scottie +201:Tibetan terrier, chrysanthemum dog +202:silky terrier, Sydney silky +203:soft-coated wheaten terrier +204:West Highland white terrier +205:Lhasa, Lhasa apso +206:flat-coated retriever +207:curly-coated retriever +208:golden retriever +209:Labrador retriever +210:Chesapeake Bay retriever +211:German short-haired pointer +212:vizsla, Hungarian pointer +213:English setter +214:Irish setter, red setter +215:Gordon setter +216:Brittany spaniel +217:clumber, clumber spaniel +218:English springer, English springer spaniel +219:Welsh springer spaniel +220:cocker spaniel, English cocker spaniel, cocker +221:Sussex spaniel +222:Irish water spaniel +223:kuvasz +224:schipperke +225:groenendael +226:malinois +227:briard +228:kelpie +229:komondor +230:Old English sheepdog, bobtail +231:Shetland sheepdog, Shetland sheep dog, Shetland +232:collie +233:Border collie +234:Bouvier des Flandres, Bouviers des Flandres +235:Rottweiler +236:German shepherd, German shepherd dog, German police dog, alsatian +237:Doberman, Doberman pinscher +238:miniature pinscher +239:Greater Swiss Mountain dog +240:Bernese mountain dog +241:Appenzeller +242:EntleBucher +243:boxer +244:bull mastiff +245:Tibetan mastiff +246:French bulldog +247:Great Dane +248:Saint Bernard, St Bernard +249:Eskimo dog, husky +250:malamute, malemute, Alaskan malamute +251:Siberian husky +252:dalmatian, coach dog, carriage dog +253:affenpinscher, monkey pinscher, monkey dog +254:basenji +255:pug, pug-dog +256:Leonberg +257:Newfoundland, Newfoundland dog +258:Great Pyrenees +259:Samoyed, Samoyede +260:Pomeranian +261:chow, chow chow +262:keeshond +263:Brabancon griffon +264:Pembroke, Pembroke Welsh corgi +265:Cardigan, Cardigan Welsh corgi +266:toy poodle +267:miniature poodle +268:standard poodle +269:Mexican hairless +270:timber wolf, grey wolf, gray wolf, Canis lupus +271:white wolf, Arctic wolf, Canis lupus tundrarum +272:red wolf, maned wolf, Canis rufus, Canis niger +273:coyote, prairie wolf, brush wolf, Canis latrans +274:dingo, warrigal, warragal, Canis dingo +275:dhole, Cuon alpinus +276:African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus +277:hyena, hyaena +278:red fox, Vulpes vulpes +279:kit fox, Vulpes macrotis +280:Arctic fox, white fox, Alopex lagopus +281:grey fox, gray fox, Urocyon cinereoargenteus +282:tabby, tabby cat +283:tiger cat +284:Persian cat +285:Siamese cat, Siamese +286:Egyptian cat +287:cougar, puma, catamount, mountain lion, painter, panther, Felis concolor +288:lynx, catamount +289:leopard, Panthera pardus +290:snow leopard, ounce, Panthera uncia +291:jaguar, panther, Panthera onca, Felis onca +292:lion, king of beasts, Panthera leo +293:tiger, Panthera tigris +294:cheetah, chetah, Acinonyx jubatus +295:brown bear, bruin, Ursus arctos +296:American black bear, black bear, Ursus americanus, Euarctos americanus +297:ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus +298:sloth bear, Melursus ursinus, Ursus ursinus +299:mongoose +300:meerkat, mierkat +301:tiger beetle +302:ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle +303:ground beetle, carabid beetle +304:long-horned beetle, longicorn, longicorn beetle +305:leaf beetle, chrysomelid +306:dung beetle +307:rhinoceros beetle +308:weevil +309:fly +310:bee +311:ant, emmet, pismire +312:grasshopper, hopper +313:cricket +314:walking stick, walkingstick, stick insect +315:cockroach, roach +316:mantis, mantid +317:cicada, cicala +318:leafhopper +319:lacewing, lacewing fly +320:dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk +321:damselfly +322:admiral +323:ringlet, ringlet butterfly +324:monarch, monarch butterfly, milkweed butterfly, Danaus plexippus +325:cabbage butterfly +326:sulphur butterfly, sulfur butterfly +327:lycaenid, lycaenid butterfly +328:starfish, sea star +329:sea urchin +330:sea cucumber, holothurian +331:wood rabbit, cottontail, cottontail rabbit +332:hare +333:Angora, Angora rabbit +334:hamster +335:porcupine, hedgehog +336:fox squirrel, eastern fox squirrel, Sciurus niger +337:marmot +338:beaver +339:guinea pig, Cavia cobaya +340:sorrel +341:zebra +342:hog, pig, grunter, squealer, Sus scrofa +343:wild boar, boar, Sus scrofa +344:warthog +345:hippopotamus, hippo, river horse, Hippopotamus amphibius +346:ox +347:water buffalo, water ox, Asiatic buffalo, Bubalus bubalis +348:bison +349:ram, tup +350:bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis +351:ibex, Capra ibex +352:hartebeest +353:impala, Aepyceros melampus +354:gazelle +355:Arabian camel, dromedary, Camelus dromedarius +356:llama +357:weasel +358:mink +359:polecat, fitch, foulmart, foumart, Mustela putorius +360:black-footed ferret, ferret, Mustela nigripes +361:otter +362:skunk, polecat, wood pussy +363:badger +364:armadillo +365:three-toed sloth, ai, Bradypus tridactylus +366:orangutan, orang, orangutang, Pongo pygmaeus +367:gorilla, Gorilla gorilla +368:chimpanzee, chimp, Pan troglodytes +369:gibbon, Hylobates lar +370:siamang, Hylobates syndactylus, Symphalangus syndactylus +371:guenon, guenon monkey +372:patas, hussar monkey, Erythrocebus patas +373:baboon +374:macaque +375:langur +376:colobus, colobus monkey +377:proboscis monkey, Nasalis larvatus +378:marmoset +379:capuchin, ringtail, Cebus capucinus +380:howler monkey, howler +381:titi, titi monkey +382:spider monkey, Ateles geoffroyi +383:squirrel monkey, Saimiri sciureus +384:Madagascar cat, ring-tailed lemur, Lemur catta +385:indri, indris, Indri indri, Indri brevicaudatus +386:Indian elephant, Elephas maximus +387:African elephant, Loxodonta africana +388:lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens +389:giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca +390:barracouta, snoek +391:eel +392:coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch +393:rock beauty, Holocanthus tricolor +394:anemone fish +395:sturgeon +396:gar, garfish, garpike, billfish, Lepisosteus osseus +397:lionfish +398:puffer, pufferfish, blowfish, globefish +399:abacus +400:abaya +401:academic gown, academic robe, judge's robe +402:accordion, piano accordion, squeeze box +403:acoustic guitar +404:aircraft carrier, carrier, flattop, attack aircraft carrier +405:airliner +406:airship, dirigible +407:altar +408:ambulance +409:amphibian, amphibious vehicle +410:analog clock +411:apiary, bee house +412:apron +413:ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin +414:assault rifle, assault gun +415:backpack, back pack, knapsack, packsack, rucksack, haversack +416:bakery, bakeshop, bakehouse +417:balance beam, beam +418:balloon +419:ballpoint, ballpoint pen, ballpen, Biro +420:Band Aid +421:banjo +422:bannister, banister, balustrade, balusters, handrail +423:barbell +424:barber chair +425:barbershop +426:barn +427:barometer +428:barrel, cask +429:barrow, garden cart, lawn cart, wheelbarrow +430:baseball +431:basketball +432:bassinet +433:bassoon +434:bathing cap, swimming cap +435:bath towel +436:bathtub, bathing tub, bath, tub +437:beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon +438:beacon, lighthouse, beacon light, pharos +439:beaker +440:bearskin, busby, shako +441:beer bottle +442:beer glass +443:bell cote, bell cot +444:bib +445:bicycle-built-for-two, tandem bicycle, tandem +446:bikini, two-piece +447:binder, ring-binder +448:binoculars, field glasses, opera glasses +449:birdhouse +450:boathouse +451:bobsled, bobsleigh, bob +452:bolo tie, bolo, bola tie, bola +453:bonnet, poke bonnet +454:bookcase +455:bookshop, bookstore, bookstall +456:bottlecap +457:bow +458:bow tie, bow-tie, bowtie +459:brass, memorial tablet, plaque +460:brassiere, bra, bandeau +461:breakwater, groin, groyne, mole, bulwark, seawall, jetty +462:breastplate, aegis, egis +463:broom +464:bucket, pail +465:buckle +466:bulletproof vest +467:bullet train, bullet +468:butcher shop, meat market +469:cab, hack, taxi, taxicab +470:caldron, cauldron +471:candle, taper, wax light +472:cannon +473:canoe +474:can opener, tin opener +475:cardigan +476:car mirror +477:carousel, carrousel, merry-go-round, roundabout, whirligig +478:carpenter's kit, tool kit +479:carton +480:car wheel +481:cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM +482:cassette +483:cassette player +484:castle +485:catamaran +486:CD player +487:cello, violoncello +488:cellular telephone, cellular phone, cellphone, cell, mobile phone +489:chain +490:chainlink fence +491:chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour +492:chain saw, chainsaw +493:chest +494:chiffonier, commode +495:chime, bell, gong +496:china cabinet, china closet +497:Christmas stocking +498:church, church building +499:cinema, movie theater, movie theatre, movie house, picture palace +500:cleaver, meat cleaver, chopper +501:cliff dwelling +502:cloak +503:clog, geta, patten, sabot +504:cocktail shaker +505:coffee mug +506:coffeepot +507:coil, spiral, volute, whorl, helix +508:combination lock +509:computer keyboard, keypad +510:confectionery, confectionary, candy store +511:container ship, containership, container vessel +512:convertible +513:corkscrew, bottle screw +514:cornet, horn, trumpet, trump +515:cowboy boot +516:cowboy hat, ten-gallon hat +517:cradle +518:crane +519:crash helmet +520:crate +521:crib, cot +522:Crock Pot +523:croquet ball +524:crutch +525:cuirass +526:dam, dike, dyke +527:desk +528:desktop computer +529:dial telephone, dial phone +530:diaper, nappy, napkin +531:digital clock +532:digital watch +533:dining table, board +534:dishrag, dishcloth +535:dishwasher, dish washer, dishwashing machine +536:disk brake, disc brake +537:dock, dockage, docking facility +538:dogsled, dog sled, dog sleigh +539:dome +540:doormat, welcome mat +541:drilling platform, offshore rig +542:drum, membranophone, tympan +543:drumstick +544:dumbbell +545:Dutch oven +546:electric fan, blower +547:electric guitar +548:electric locomotive +549:entertainment center +550:envelope +551:espresso maker +552:face powder +553:feather boa, boa +554:file, file cabinet, filing cabinet +555:fireboat +556:fire engine, fire truck +557:fire screen, fireguard +558:flagpole, flagstaff +559:flute, transverse flute +560:folding chair +561:football helmet +562:forklift +563:fountain +564:fountain pen +565:four-poster +566:freight car +567:French horn, horn +568:frying pan, frypan, skillet +569:fur coat +570:garbage truck, dustcart +571:gasmask, respirator, gas helmet +572:gas pump, gasoline pump, petrol pump, island dispenser +573:goblet +574:go-kart +575:golf ball +576:golfcart, golf cart +577:gondola +578:gong, tam-tam +579:gown +580:grand piano, grand +581:greenhouse, nursery, glasshouse +582:grille, radiator grille +583:grocery store, grocery, food market, market +584:guillotine +585:hair slide +586:hair spray +587:half track +588:hammer +589:hamper +590:hand blower, blow dryer, blow drier, hair dryer, hair drier +591:hand-held computer, hand-held microcomputer +592:handkerchief, hankie, hanky, hankey +593:hard disc, hard disk, fixed disk +594:harmonica, mouth organ, harp, mouth harp +595:harp +596:harvester, reaper +597:hatchet +598:holster +599:home theater, home theatre +600:honeycomb +601:hook, claw +602:hoopskirt, crinoline +603:horizontal bar, high bar +604:horse cart, horse-cart +605:hourglass +606:iPod +607:iron, smoothing iron +608:jack-o'-lantern +609:jean, blue jean, denim +610:jeep, landrover +611:jersey, T-shirt, tee shirt +612:jigsaw puzzle +613:jinrikisha, ricksha, rickshaw +614:joystick +615:kimono +616:knee pad +617:knot +618:lab coat, laboratory coat +619:ladle +620:lampshade, lamp shade +621:laptop, laptop computer +622:lawn mower, mower +623:lens cap, lens cover +624:letter opener, paper knife, paperknife +625:library +626:lifeboat +627:lighter, light, igniter, ignitor +628:limousine, limo +629:liner, ocean liner +630:lipstick, lip rouge +631:Loafer +632:lotion +633:loudspeaker, speaker, speaker unit, loudspeaker system, speaker system +634:loupe, jeweler's loupe +635:lumbermill, sawmill +636:magnetic compass +637:mailbag, postbag +638:mailbox, letter box +639:maillot +640:maillot, tank suit +641:manhole cover +642:maraca +643:marimba, xylophone +644:mask +645:matchstick +646:maypole +647:maze, labyrinth +648:measuring cup +649:medicine chest, medicine cabinet +650:megalith, megalithic structure +651:microphone, mike +652:microwave, microwave oven +653:military uniform +654:milk can +655:minibus +656:miniskirt, mini +657:minivan +658:missile +659:mitten +660:mixing bowl +661:mobile home, manufactured home +662:Model T +663:modem +664:monastery +665:monitor +666:moped +667:mortar +668:mortarboard +669:mosque +670:mosquito net +671:motor scooter, scooter +672:mountain bike, all-terrain bike, off-roader +673:mountain tent +674:mouse, computer mouse +675:mousetrap +676:moving van +677:muzzle +678:nail +679:neck brace +680:necklace +681:nipple +682:notebook, notebook computer +683:obelisk +684:oboe, hautboy, hautbois +685:ocarina, sweet potato +686:odometer, hodometer, mileometer, milometer +687:oil filter +688:organ, pipe organ +689:oscilloscope, scope, cathode-ray oscilloscope, CRO +690:overskirt +691:oxcart +692:oxygen mask +693:packet +694:paddle, boat paddle +695:paddlewheel, paddle wheel +696:padlock +697:paintbrush +698:pajama, pyjama, pj's, jammies +699:palace +700:panpipe, pandean pipe, syrinx +701:paper towel +702:parachute, chute +703:parallel bars, bars +704:park bench +705:parking meter +706:passenger car, coach, carriage +707:patio, terrace +708:pay-phone, pay-station +709:pedestal, plinth, footstall +710:pencil box, pencil case +711:pencil sharpener +712:perfume, essence +713:Petri dish +714:photocopier +715:pick, plectrum, plectron +716:pickelhaube +717:picket fence, paling +718:pickup, pickup truck +719:pier +720:piggy bank, penny bank +721:pill bottle +722:pillow +723:ping-pong ball +724:pinwheel +725:pirate, pirate ship +726:pitcher, ewer +727:plane, carpenter's plane, woodworking plane +728:planetarium +729:plastic bag +730:plate rack +731:plow, plough +732:plunger, plumber's helper +733:Polaroid camera, Polaroid Land camera +734:pole +735:police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria +736:poncho +737:pool table, billiard table, snooker table +738:pop bottle, soda bottle +739:pot, flowerpot +740:potter's wheel +741:power drill +742:prayer rug, prayer mat +743:printer +744:prison, prison house +745:projectile, missile +746:projector +747:puck, hockey puck +748:punching bag, punch bag, punching ball, punchball +749:purse +750:quill, quill pen +751:quilt, comforter, comfort, puff +752:racer, race car, racing car +753:racket, racquet +754:radiator +755:radio, wireless +756:radio telescope, radio reflector +757:rain barrel +758:recreational vehicle, RV, R.V. +759:reel +760:reflex camera +761:refrigerator, icebox +762:remote control, remote +763:restaurant, eating house, eating place, eatery +764:revolver, six-gun, six-shooter +765:rifle +766:rocking chair, rocker +767:rotisserie +768:rubber eraser, rubber, pencil eraser +769:rugby ball +770:rule, ruler +771:running shoe +772:safe +773:safety pin +774:saltshaker, salt shaker +775:sandal +776:sarong +777:sax, saxophone +778:scabbard +779:scale, weighing machine +780:school bus +781:schooner +782:scoreboard +783:screen, CRT screen +784:screw +785:screwdriver +786:seat belt, seatbelt +787:sewing machine +788:shield, buckler +789:shoe shop, shoe-shop, shoe store +790:shoji +791:shopping basket +792:shopping cart +793:shovel +794:shower cap +795:shower curtain +796:ski +797:ski mask +798:sleeping bag +799:slide rule, slipstick +800:sliding door +801:slot, one-armed bandit +802:snorkel +803:snowmobile +804:snowplow, snowplough +805:soap dispenser +806:soccer ball +807:sock +808:solar dish, solar collector, solar furnace +809:sombrero +810:soup bowl +811:space bar +812:space heater +813:space shuttle +814:spatula +815:speedboat +816:spider web, spider's web +817:spindle +818:sports car, sport car +819:spotlight, spot +820:stage +821:steam locomotive +822:steel arch bridge +823:steel drum +824:stethoscope +825:stole +826:stone wall +827:stopwatch, stop watch +828:stove +829:strainer +830:streetcar, tram, tramcar, trolley, trolley car +831:stretcher +832:studio couch, day bed +833:stupa, tope +834:submarine, pigboat, sub, U-boat +835:suit, suit of clothes +836:sundial +837:sunglass +838:sunglasses, dark glasses, shades +839:sunscreen, sunblock, sun blocker +840:suspension bridge +841:swab, swob, mop +842:sweatshirt +843:swimming trunks, bathing trunks +844:swing +845:switch, electric switch, electrical switch +846:syringe +847:table lamp +848:tank, army tank, armored combat vehicle, armoured combat vehicle +849:tape player +850:teapot +851:teddy, teddy bear +852:television, television system +853:tennis ball +854:thatch, thatched roof +855:theater curtain, theatre curtain +856:thimble +857:thresher, thrasher, threshing machine +858:throne +859:tile roof +860:toaster +861:tobacco shop, tobacconist shop, tobacconist +862:toilet seat +863:torch +864:totem pole +865:tow truck, tow car, wrecker +866:toyshop +867:tractor +868:trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi +869:tray +870:trench coat +871:tricycle, trike, velocipede +872:trimaran +873:tripod +874:triumphal arch +875:trolleybus, trolley coach, trackless trolley +876:trombone +877:tub, vat +878:turnstile +879:typewriter keyboard +880:umbrella +881:unicycle, monocycle +882:upright, upright piano +883:vacuum, vacuum cleaner +884:vase +885:vault +886:velvet +887:vending machine +888:vestment +889:viaduct +890:violin, fiddle +891:volleyball +892:waffle iron +893:wall clock +894:wallet, billfold, notecase, pocketbook +895:wardrobe, closet, press +896:warplane, military plane +897:washbasin, handbasin, washbowl, lavabo, wash-hand basin +898:washer, automatic washer, washing machine +899:water bottle +900:water jug +901:water tower +902:whiskey jug +903:whistle +904:wig +905:window screen +906:window shade +907:Windsor tie +908:wine bottle +909:wing +910:wok +911:wooden spoon +912:wool, woolen, woollen +913:worm fence, snake fence, snake-rail fence, Virginia fence +914:wreck +915:yawl +916:yurt +917:web site, website, internet site, site +918:comic book +919:crossword puzzle, crossword +920:street sign +921:traffic light, traffic signal, stoplight +922:book jacket, dust cover, dust jacket, dust wrapper +923:menu +924:plate +925:guacamole +926:consomme +927:hot pot, hotpot +928:trifle +929:ice cream, icecream +930:ice lolly, lolly, lollipop, popsicle +931:French loaf +932:bagel, beigel +933:pretzel +934:cheeseburger +935:hotdog, hot dog, red hot +936:mashed potato +937:head cabbage +938:broccoli +939:cauliflower +940:zucchini, courgette +941:spaghetti squash +942:acorn squash +943:butternut squash +944:cucumber, cuke +945:artichoke, globe artichoke +946:bell pepper +947:cardoon +948:mushroom +949:Granny Smith +950:strawberry +951:orange +952:lemon +953:fig +954:pineapple, ananas +955:banana +956:jackfruit, jak, jack +957:custard apple +958:pomegranate +959:hay +960:carbonara +961:chocolate sauce, chocolate syrup +962:dough +963:meat loaf, meatloaf +964:pizza, pizza pie +965:potpie +966:burrito +967:red wine +968:espresso +969:cup +970:eggnog +971:alp +972:bubble +973:cliff, drop, drop-off +974:coral reef +975:geyser +976:lakeside, lakeshore +977:promontory, headland, head, foreland +978:sandbar, sand bar +979:seashore, coast, seacoast, sea-coast +980:valley, vale +981:volcano +982:ballplayer, baseball player +983:groom, bridegroom +984:scuba diver +985:rapeseed +986:daisy +987:yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum +988:corn +989:acorn +990:hip, rose hip, rosehip +991:buckeye, horse chestnut, conker +992:coral fungus +993:agaric +994:gyromitra +995:stinkhorn, carrion fungus +996:earthstar +997:hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa +998:bolete +999:ear, spike, capitulum +1000:toilet tissue, toilet paper, bathroom tissue diff --git a/tests/TfCifar10-Armnn/TfCifar10-Armnn.cpp b/tests/TfCifar10-Armnn/TfCifar10-Armnn.cpp index cfe95095a9..ee2e880951 100644 --- a/tests/TfCifar10-Armnn/TfCifar10-Armnn.cpp +++ b/tests/TfCifar10-Armnn/TfCifar10-Armnn.cpp @@ -13,12 +13,18 @@ int main(int argc, char* argv[]) int retVal = EXIT_FAILURE; try { + using DataType = float; + using DatabaseType = Cifar10Database; + using ParserType = armnnTfParser::ITfParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "cifar10_tf.prototxt", false, "data", "prob", { 0, 1, 2, 4, 7 }, - [](const char* dataDir) { return Cifar10Database(dataDir, true); }, - &inputTensorShape); + [](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir, true); + }, &inputTensorShape); } catch (const std::exception& e) { diff --git a/tests/TfInceptionV3-Armnn/TfInceptionV3-Armnn.cpp b/tests/TfInceptionV3-Armnn/TfInceptionV3-Armnn.cpp index 441b07c9c9..09e70018d3 100644 --- a/tests/TfInceptionV3-Armnn/TfInceptionV3-Armnn.cpp +++ b/tests/TfInceptionV3-Armnn/TfInceptionV3-Armnn.cpp @@ -3,7 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../MobileNetDatabase.hpp" +#include "../ImagePreprocessor.hpp" #include "armnnTfParser/ITfParser.hpp" int main(int argc, char* argv[]) @@ -21,11 +21,18 @@ int main(int argc, char* argv[]) armnn::TensorShape inputTensorShape({ 1, 299, 299, 3 }); + using DataType = float; + using DatabaseType = ImagePreprocessor; + using ParserType = armnnTfParser::ITfParser; + using ModelType = InferenceModel; + // Coverity fix: InferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "inception_v3_2016_08_28_frozen_transformed.pb", true, "input", "InceptionV3/Predictions/Reshape_1", { 0, 1, 2, }, - [&imageSet](const char* dataDir) { return MobileNetDatabase(dataDir, 299, 299, imageSet); }, + [&imageSet](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir, 299, 299, imageSet); + }, &inputTensorShape); } catch (const std::exception& e) diff --git a/tests/TfLiteMobilenetQuantized-Armnn/TfLiteMobilenetQuantized-Armnn.cpp b/tests/TfLiteMobilenetQuantized-Armnn/TfLiteMobilenetQuantized-Armnn.cpp new file mode 100644 index 0000000000..7383ab3d94 --- /dev/null +++ b/tests/TfLiteMobilenetQuantized-Armnn/TfLiteMobilenetQuantized-Armnn.cpp @@ -0,0 +1,84 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include "../InferenceTest.hpp" +#include "../ImagePreprocessor.hpp" +#include "armnnTfLiteParser/ITfLiteParser.hpp" + +using namespace armnnTfLiteParser; + +int main(int argc, char* argv[]) +{ + int retVal = EXIT_FAILURE; + try + { + // Coverity fix: The following code may throw an exception of type std::length_error. + std::vector imageSet = + { + {"Dog.jpg", 209}, + // top five predictions in tensorflow: + // ----------------------------------- + // 209:Labrador retriever 0.949995 + // 160:Rhodesian ridgeback 0.0270182 + // 208:golden retriever 0.0192866 + // 853:tennis ball 0.000470382 + // 239:Greater Swiss Mountain dog 0.000464451 + {"Cat.jpg", 283}, + // top five predictions in tensorflow: + // ----------------------------------- + // 283:tiger cat 0.579016 + // 286:Egyptian cat 0.319676 + // 282:tabby, tabby cat 0.0873346 + // 288:lynx, catamount 0.011163 + // 289:leopard, Panthera pardus 0.000856755 + {"shark.jpg", 3}, + // top five predictions in tensorflow: + // ----------------------------------- + // 3:great white shark, white shark, ... 0.996926 + // 4:tiger shark, Galeocerdo cuvieri 0.00270528 + // 149:killer whale, killer, orca, ... 0.000121848 + // 395:sturgeon 7.78977e-05 + // 5:hammerhead, hammerhead shark 6.44127e-055 + }; + + armnn::TensorShape inputTensorShape({ 1, 224, 224, 3 }); + + using DataType = uint8_t; + using DatabaseType = ImagePreprocessor; + using ParserType = armnnTfLiteParser::ITfLiteParser; + using ModelType = InferenceModel; + + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. + retVal = armnn::test::ClassifierInferenceTestMain( + argc, argv, + "mobilenet_v1_1.0_224_quant.tflite", // model name + true, // model is binary + "input", // input tensor name + "MobilenetV1/Predictions/Reshape_1", // output tensor name + { 0, 1, 2 }, // test images to test with as above + [&imageSet](const char* dataDir, const ModelType & model) { + // we need to get the input quantization parameters from + // the parsed model + auto inputBinding = model.GetInputBindingInfo(); + return DatabaseType( + dataDir, + 224, + 224, + imageSet, + inputBinding.second.GetQuantizationScale(), + inputBinding.second.GetQuantizationOffset()); + }, + &inputTensorShape); + } + catch (const std::exception& e) + { + // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an + // exception of type std::length_error. + // Using stderr instead in this context as there is no point in nesting try-catch blocks here. + std::cerr << "WARNING: " << *argv << ": An error has occurred when running " + "the classifier inference tests: " << e.what() << std::endl; + } + return retVal; +} diff --git a/tests/TfLiteMobilenetQuantized-Armnn/Validation.txt b/tests/TfLiteMobilenetQuantized-Armnn/Validation.txt new file mode 100644 index 0000000000..94a11bdabc --- /dev/null +++ b/tests/TfLiteMobilenetQuantized-Armnn/Validation.txt @@ -0,0 +1,201 @@ +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 +209 +283 +3 \ No newline at end of file diff --git a/tests/TfLiteMobilenetQuantized-Armnn/labels.txt b/tests/TfLiteMobilenetQuantized-Armnn/labels.txt new file mode 100644 index 0000000000..d74ff557dd --- /dev/null +++ b/tests/TfLiteMobilenetQuantized-Armnn/labels.txt @@ -0,0 +1,1001 @@ +0:background +1:tench, Tinca tinca +2:goldfish, Carassius auratus +3:great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +4:tiger shark, Galeocerdo cuvieri +5:hammerhead, hammerhead shark +6:electric ray, crampfish, numbfish, torpedo +7:stingray +8:cock +9:hen +10:ostrich, Struthio camelus +11:brambling, Fringilla montifringilla +12:goldfinch, Carduelis carduelis +13:house finch, linnet, Carpodacus mexicanus +14:junco, snowbird +15:indigo bunting, indigo finch, indigo bird, Passerina cyanea +16:robin, American robin, Turdus migratorius +17:bulbul +18:jay +19:magpie +20:chickadee +21:water ouzel, dipper +22:kite +23:bald eagle, American eagle, Haliaeetus leucocephalus +24:vulture +25:great grey owl, great gray owl, Strix nebulosa +26:European fire salamander, Salamandra salamandra +27:common newt, Triturus vulgaris +28:eft +29:spotted salamander, Ambystoma maculatum +30:axolotl, mud puppy, Ambystoma mexicanum +31:bullfrog, Rana catesbeiana +32:tree frog, tree-frog +33:tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui +34:loggerhead, loggerhead turtle, Caretta caretta +35:leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea +36:mud turtle +37:terrapin +38:box turtle, box tortoise +39:banded gecko +40:common iguana, iguana, Iguana iguana +41:American chameleon, anole, Anolis carolinensis +42:whiptail, whiptail lizard +43:agama +44:frilled lizard, Chlamydosaurus kingi +45:alligator lizard +46:Gila monster, Heloderma suspectum +47:green lizard, Lacerta viridis +48:African chameleon, Chamaeleo chamaeleon +49:Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis +50:African crocodile, Nile crocodile, Crocodylus niloticus +51:American alligator, Alligator mississipiensis +52:triceratops +53:thunder snake, worm snake, Carphophis amoenus +54:ringneck snake, ring-necked snake, ring snake +55:hognose snake, puff adder, sand viper +56:green snake, grass snake +57:king snake, kingsnake +58:garter snake, grass snake +59:water snake +60:vine snake +61:night snake, Hypsiglena torquata +62:boa constrictor, Constrictor constrictor +63:rock python, rock snake, Python sebae +64:Indian cobra, Naja naja +65:green mamba +66:sea snake +67:horned viper, cerastes, sand viper, horned asp, Cerastes cornutus +68:diamondback, diamondback rattlesnake, Crotalus adamanteus +69:sidewinder, horned rattlesnake, Crotalus cerastes +70:trilobite +71:harvestman, daddy longlegs, Phalangium opilio +72:scorpion +73:black and gold garden spider, Argiope aurantia +74:barn spider, Araneus cavaticus +75:garden spider, Aranea diademata +76:black widow, Latrodectus mactans +77:tarantula +78:wolf spider, hunting spider +79:tick +80:centipede +81:black grouse +82:ptarmigan +83:ruffed grouse, partridge, Bonasa umbellus +84:prairie chicken, prairie grouse, prairie fowl +85:peacock +86:quail +87:partridge +88:African grey, African gray, Psittacus erithacus +89:macaw +90:sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita +91:lorikeet +92:coucal +93:bee eater +94:hornbill +95:hummingbird +96:jacamar +97:toucan +98:drake +99:red-breasted merganser, Mergus serrator +100:goose +101:black swan, Cygnus atratus +102:tusker +103:echidna, spiny anteater, anteater +104:platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus +105:wallaby, brush kangaroo +106:koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus +107:wombat +108:jellyfish +109:sea anemone, anemone +110:brain coral +111:flatworm, platyhelminth +112:nematode, nematode worm, roundworm +113:conch +114:snail +115:slug +116:sea slug, nudibranch +117:chiton, coat-of-mail shell, sea cradle, polyplacophore +118:chambered nautilus, pearly nautilus, nautilus +119:Dungeness crab, Cancer magister +120:rock crab, Cancer irroratus +121:fiddler crab +122:king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica +123:American lobster, Northern lobster, Maine lobster, Homarus americanus +124:spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish +125:crayfish, crawfish, crawdad, crawdaddy +126:hermit crab +127:isopod +128:white stork, Ciconia ciconia +129:black stork, Ciconia nigra +130:spoonbill +131:flamingo +132:little blue heron, Egretta caerulea +133:American egret, great white heron, Egretta albus +134:bittern +135:crane +136:limpkin, Aramus pictus +137:European gallinule, Porphyrio porphyrio +138:American coot, marsh hen, mud hen, water hen, Fulica americana +139:bustard +140:ruddy turnstone, Arenaria interpres +141:red-backed sandpiper, dunlin, Erolia alpina +142:redshank, Tringa totanus +143:dowitcher +144:oystercatcher, oyster catcher +145:pelican +146:king penguin, Aptenodytes patagonica +147:albatross, mollymawk +148:grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus +149:killer whale, killer, orca, grampus, sea wolf, Orcinus orca +150:dugong, Dugong dugon +151:sea lion +152:Chihuahua +153:Japanese spaniel +154:Maltese dog, Maltese terrier, Maltese +155:Pekinese, Pekingese, Peke +156:Shih-Tzu +157:Blenheim spaniel +158:papillon +159:toy terrier +160:Rhodesian ridgeback +161:Afghan hound, Afghan +162:basset, basset hound +163:beagle +164:bloodhound, sleuthhound +165:bluetick +166:black-and-tan coonhound +167:Walker hound, Walker foxhound +168:English foxhound +169:redbone +170:borzoi, Russian wolfhound +171:Irish wolfhound +172:Italian greyhound +173:whippet +174:Ibizan hound, Ibizan Podenco +175:Norwegian elkhound, elkhound +176:otterhound, otter hound +177:Saluki, gazelle hound +178:Scottish deerhound, deerhound +179:Weimaraner +180:Staffordshire bullterrier, Staffordshire bull terrier +181:American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier +182:Bedlington terrier +183:Border terrier +184:Kerry blue terrier +185:Irish terrier +186:Norfolk terrier +187:Norwich terrier +188:Yorkshire terrier +189:wire-haired fox terrier +190:Lakeland terrier +191:Sealyham terrier, Sealyham +192:Airedale, Airedale terrier +193:cairn, cairn terrier +194:Australian terrier +195:Dandie Dinmont, Dandie Dinmont terrier +196:Boston bull, Boston terrier +197:miniature schnauzer +198:giant schnauzer +199:standard schnauzer +200:Scotch terrier, Scottish terrier, Scottie +201:Tibetan terrier, chrysanthemum dog +202:silky terrier, Sydney silky +203:soft-coated wheaten terrier +204:West Highland white terrier +205:Lhasa, Lhasa apso +206:flat-coated retriever +207:curly-coated retriever +208:golden retriever +209:Labrador retriever +210:Chesapeake Bay retriever +211:German short-haired pointer +212:vizsla, Hungarian pointer +213:English setter +214:Irish setter, red setter +215:Gordon setter +216:Brittany spaniel +217:clumber, clumber spaniel +218:English springer, English springer spaniel +219:Welsh springer spaniel +220:cocker spaniel, English cocker spaniel, cocker +221:Sussex spaniel +222:Irish water spaniel +223:kuvasz +224:schipperke +225:groenendael +226:malinois +227:briard +228:kelpie +229:komondor +230:Old English sheepdog, bobtail +231:Shetland sheepdog, Shetland sheep dog, Shetland +232:collie +233:Border collie +234:Bouvier des Flandres, Bouviers des Flandres +235:Rottweiler +236:German shepherd, German shepherd dog, German police dog, alsatian +237:Doberman, Doberman pinscher +238:miniature pinscher +239:Greater Swiss Mountain dog +240:Bernese mountain dog +241:Appenzeller +242:EntleBucher +243:boxer +244:bull mastiff +245:Tibetan mastiff +246:French bulldog +247:Great Dane +248:Saint Bernard, St Bernard +249:Eskimo dog, husky +250:malamute, malemute, Alaskan malamute +251:Siberian husky +252:dalmatian, coach dog, carriage dog +253:affenpinscher, monkey pinscher, monkey dog +254:basenji +255:pug, pug-dog +256:Leonberg +257:Newfoundland, Newfoundland dog +258:Great Pyrenees +259:Samoyed, Samoyede +260:Pomeranian +261:chow, chow chow +262:keeshond +263:Brabancon griffon +264:Pembroke, Pembroke Welsh corgi +265:Cardigan, Cardigan Welsh corgi +266:toy poodle +267:miniature poodle +268:standard poodle +269:Mexican hairless +270:timber wolf, grey wolf, gray wolf, Canis lupus +271:white wolf, Arctic wolf, Canis lupus tundrarum +272:red wolf, maned wolf, Canis rufus, Canis niger +273:coyote, prairie wolf, brush wolf, Canis latrans +274:dingo, warrigal, warragal, Canis dingo +275:dhole, Cuon alpinus +276:African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus +277:hyena, hyaena +278:red fox, Vulpes vulpes +279:kit fox, Vulpes macrotis +280:Arctic fox, white fox, Alopex lagopus +281:grey fox, gray fox, Urocyon cinereoargenteus +282:tabby, tabby cat +283:tiger cat +284:Persian cat +285:Siamese cat, Siamese +286:Egyptian cat +287:cougar, puma, catamount, mountain lion, painter, panther, Felis concolor +288:lynx, catamount +289:leopard, Panthera pardus +290:snow leopard, ounce, Panthera uncia +291:jaguar, panther, Panthera onca, Felis onca +292:lion, king of beasts, Panthera leo +293:tiger, Panthera tigris +294:cheetah, chetah, Acinonyx jubatus +295:brown bear, bruin, Ursus arctos +296:American black bear, black bear, Ursus americanus, Euarctos americanus +297:ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus +298:sloth bear, Melursus ursinus, Ursus ursinus +299:mongoose +300:meerkat, mierkat +301:tiger beetle +302:ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle +303:ground beetle, carabid beetle +304:long-horned beetle, longicorn, longicorn beetle +305:leaf beetle, chrysomelid +306:dung beetle +307:rhinoceros beetle +308:weevil +309:fly +310:bee +311:ant, emmet, pismire +312:grasshopper, hopper +313:cricket +314:walking stick, walkingstick, stick insect +315:cockroach, roach +316:mantis, mantid +317:cicada, cicala +318:leafhopper +319:lacewing, lacewing fly +320:dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk +321:damselfly +322:admiral +323:ringlet, ringlet butterfly +324:monarch, monarch butterfly, milkweed butterfly, Danaus plexippus +325:cabbage butterfly +326:sulphur butterfly, sulfur butterfly +327:lycaenid, lycaenid butterfly +328:starfish, sea star +329:sea urchin +330:sea cucumber, holothurian +331:wood rabbit, cottontail, cottontail rabbit +332:hare +333:Angora, Angora rabbit +334:hamster +335:porcupine, hedgehog +336:fox squirrel, eastern fox squirrel, Sciurus niger +337:marmot +338:beaver +339:guinea pig, Cavia cobaya +340:sorrel +341:zebra +342:hog, pig, grunter, squealer, Sus scrofa +343:wild boar, boar, Sus scrofa +344:warthog +345:hippopotamus, hippo, river horse, Hippopotamus amphibius +346:ox +347:water buffalo, water ox, Asiatic buffalo, Bubalus bubalis +348:bison +349:ram, tup +350:bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis +351:ibex, Capra ibex +352:hartebeest +353:impala, Aepyceros melampus +354:gazelle +355:Arabian camel, dromedary, Camelus dromedarius +356:llama +357:weasel +358:mink +359:polecat, fitch, foulmart, foumart, Mustela putorius +360:black-footed ferret, ferret, Mustela nigripes +361:otter +362:skunk, polecat, wood pussy +363:badger +364:armadillo +365:three-toed sloth, ai, Bradypus tridactylus +366:orangutan, orang, orangutang, Pongo pygmaeus +367:gorilla, Gorilla gorilla +368:chimpanzee, chimp, Pan troglodytes +369:gibbon, Hylobates lar +370:siamang, Hylobates syndactylus, Symphalangus syndactylus +371:guenon, guenon monkey +372:patas, hussar monkey, Erythrocebus patas +373:baboon +374:macaque +375:langur +376:colobus, colobus monkey +377:proboscis monkey, Nasalis larvatus +378:marmoset +379:capuchin, ringtail, Cebus capucinus +380:howler monkey, howler +381:titi, titi monkey +382:spider monkey, Ateles geoffroyi +383:squirrel monkey, Saimiri sciureus +384:Madagascar cat, ring-tailed lemur, Lemur catta +385:indri, indris, Indri indri, Indri brevicaudatus +386:Indian elephant, Elephas maximus +387:African elephant, Loxodonta africana +388:lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens +389:giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca +390:barracouta, snoek +391:eel +392:coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch +393:rock beauty, Holocanthus tricolor +394:anemone fish +395:sturgeon +396:gar, garfish, garpike, billfish, Lepisosteus osseus +397:lionfish +398:puffer, pufferfish, blowfish, globefish +399:abacus +400:abaya +401:academic gown, academic robe, judge's robe +402:accordion, piano accordion, squeeze box +403:acoustic guitar +404:aircraft carrier, carrier, flattop, attack aircraft carrier +405:airliner +406:airship, dirigible +407:altar +408:ambulance +409:amphibian, amphibious vehicle +410:analog clock +411:apiary, bee house +412:apron +413:ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin +414:assault rifle, assault gun +415:backpack, back pack, knapsack, packsack, rucksack, haversack +416:bakery, bakeshop, bakehouse +417:balance beam, beam +418:balloon +419:ballpoint, ballpoint pen, ballpen, Biro +420:Band Aid +421:banjo +422:bannister, banister, balustrade, balusters, handrail +423:barbell +424:barber chair +425:barbershop +426:barn +427:barometer +428:barrel, cask +429:barrow, garden cart, lawn cart, wheelbarrow +430:baseball +431:basketball +432:bassinet +433:bassoon +434:bathing cap, swimming cap +435:bath towel +436:bathtub, bathing tub, bath, tub +437:beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon +438:beacon, lighthouse, beacon light, pharos +439:beaker +440:bearskin, busby, shako +441:beer bottle +442:beer glass +443:bell cote, bell cot +444:bib +445:bicycle-built-for-two, tandem bicycle, tandem +446:bikini, two-piece +447:binder, ring-binder +448:binoculars, field glasses, opera glasses +449:birdhouse +450:boathouse +451:bobsled, bobsleigh, bob +452:bolo tie, bolo, bola tie, bola +453:bonnet, poke bonnet +454:bookcase +455:bookshop, bookstore, bookstall +456:bottlecap +457:bow +458:bow tie, bow-tie, bowtie +459:brass, memorial tablet, plaque +460:brassiere, bra, bandeau +461:breakwater, groin, groyne, mole, bulwark, seawall, jetty +462:breastplate, aegis, egis +463:broom +464:bucket, pail +465:buckle +466:bulletproof vest +467:bullet train, bullet +468:butcher shop, meat market +469:cab, hack, taxi, taxicab +470:caldron, cauldron +471:candle, taper, wax light +472:cannon +473:canoe +474:can opener, tin opener +475:cardigan +476:car mirror +477:carousel, carrousel, merry-go-round, roundabout, whirligig +478:carpenter's kit, tool kit +479:carton +480:car wheel +481:cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM +482:cassette +483:cassette player +484:castle +485:catamaran +486:CD player +487:cello, violoncello +488:cellular telephone, cellular phone, cellphone, cell, mobile phone +489:chain +490:chainlink fence +491:chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour +492:chain saw, chainsaw +493:chest +494:chiffonier, commode +495:chime, bell, gong +496:china cabinet, china closet +497:Christmas stocking +498:church, church building +499:cinema, movie theater, movie theatre, movie house, picture palace +500:cleaver, meat cleaver, chopper +501:cliff dwelling +502:cloak +503:clog, geta, patten, sabot +504:cocktail shaker +505:coffee mug +506:coffeepot +507:coil, spiral, volute, whorl, helix +508:combination lock +509:computer keyboard, keypad +510:confectionery, confectionary, candy store +511:container ship, containership, container vessel +512:convertible +513:corkscrew, bottle screw +514:cornet, horn, trumpet, trump +515:cowboy boot +516:cowboy hat, ten-gallon hat +517:cradle +518:crane +519:crash helmet +520:crate +521:crib, cot +522:Crock Pot +523:croquet ball +524:crutch +525:cuirass +526:dam, dike, dyke +527:desk +528:desktop computer +529:dial telephone, dial phone +530:diaper, nappy, napkin +531:digital clock +532:digital watch +533:dining table, board +534:dishrag, dishcloth +535:dishwasher, dish washer, dishwashing machine +536:disk brake, disc brake +537:dock, dockage, docking facility +538:dogsled, dog sled, dog sleigh +539:dome +540:doormat, welcome mat +541:drilling platform, offshore rig +542:drum, membranophone, tympan +543:drumstick +544:dumbbell +545:Dutch oven +546:electric fan, blower +547:electric guitar +548:electric locomotive +549:entertainment center +550:envelope +551:espresso maker +552:face powder +553:feather boa, boa +554:file, file cabinet, filing cabinet +555:fireboat +556:fire engine, fire truck +557:fire screen, fireguard +558:flagpole, flagstaff +559:flute, transverse flute +560:folding chair +561:football helmet +562:forklift +563:fountain +564:fountain pen +565:four-poster +566:freight car +567:French horn, horn +568:frying pan, frypan, skillet +569:fur coat +570:garbage truck, dustcart +571:gasmask, respirator, gas helmet +572:gas pump, gasoline pump, petrol pump, island dispenser +573:goblet +574:go-kart +575:golf ball +576:golfcart, golf cart +577:gondola +578:gong, tam-tam +579:gown +580:grand piano, grand +581:greenhouse, nursery, glasshouse +582:grille, radiator grille +583:grocery store, grocery, food market, market +584:guillotine +585:hair slide +586:hair spray +587:half track +588:hammer +589:hamper +590:hand blower, blow dryer, blow drier, hair dryer, hair drier +591:hand-held computer, hand-held microcomputer +592:handkerchief, hankie, hanky, hankey +593:hard disc, hard disk, fixed disk +594:harmonica, mouth organ, harp, mouth harp +595:harp +596:harvester, reaper +597:hatchet +598:holster +599:home theater, home theatre +600:honeycomb +601:hook, claw +602:hoopskirt, crinoline +603:horizontal bar, high bar +604:horse cart, horse-cart +605:hourglass +606:iPod +607:iron, smoothing iron +608:jack-o'-lantern +609:jean, blue jean, denim +610:jeep, landrover +611:jersey, T-shirt, tee shirt +612:jigsaw puzzle +613:jinrikisha, ricksha, rickshaw +614:joystick +615:kimono +616:knee pad +617:knot +618:lab coat, laboratory coat +619:ladle +620:lampshade, lamp shade +621:laptop, laptop computer +622:lawn mower, mower +623:lens cap, lens cover +624:letter opener, paper knife, paperknife +625:library +626:lifeboat +627:lighter, light, igniter, ignitor +628:limousine, limo +629:liner, ocean liner +630:lipstick, lip rouge +631:Loafer +632:lotion +633:loudspeaker, speaker, speaker unit, loudspeaker system, speaker system +634:loupe, jeweler's loupe +635:lumbermill, sawmill +636:magnetic compass +637:mailbag, postbag +638:mailbox, letter box +639:maillot +640:maillot, tank suit +641:manhole cover +642:maraca +643:marimba, xylophone +644:mask +645:matchstick +646:maypole +647:maze, labyrinth +648:measuring cup +649:medicine chest, medicine cabinet +650:megalith, megalithic structure +651:microphone, mike +652:microwave, microwave oven +653:military uniform +654:milk can +655:minibus +656:miniskirt, mini +657:minivan +658:missile +659:mitten +660:mixing bowl +661:mobile home, manufactured home +662:Model T +663:modem +664:monastery +665:monitor +666:moped +667:mortar +668:mortarboard +669:mosque +670:mosquito net +671:motor scooter, scooter +672:mountain bike, all-terrain bike, off-roader +673:mountain tent +674:mouse, computer mouse +675:mousetrap +676:moving van +677:muzzle +678:nail +679:neck brace +680:necklace +681:nipple +682:notebook, notebook computer +683:obelisk +684:oboe, hautboy, hautbois +685:ocarina, sweet potato +686:odometer, hodometer, mileometer, milometer +687:oil filter +688:organ, pipe organ +689:oscilloscope, scope, cathode-ray oscilloscope, CRO +690:overskirt +691:oxcart +692:oxygen mask +693:packet +694:paddle, boat paddle +695:paddlewheel, paddle wheel +696:padlock +697:paintbrush +698:pajama, pyjama, pj's, jammies +699:palace +700:panpipe, pandean pipe, syrinx +701:paper towel +702:parachute, chute +703:parallel bars, bars +704:park bench +705:parking meter +706:passenger car, coach, carriage +707:patio, terrace +708:pay-phone, pay-station +709:pedestal, plinth, footstall +710:pencil box, pencil case +711:pencil sharpener +712:perfume, essence +713:Petri dish +714:photocopier +715:pick, plectrum, plectron +716:pickelhaube +717:picket fence, paling +718:pickup, pickup truck +719:pier +720:piggy bank, penny bank +721:pill bottle +722:pillow +723:ping-pong ball +724:pinwheel +725:pirate, pirate ship +726:pitcher, ewer +727:plane, carpenter's plane, woodworking plane +728:planetarium +729:plastic bag +730:plate rack +731:plow, plough +732:plunger, plumber's helper +733:Polaroid camera, Polaroid Land camera +734:pole +735:police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria +736:poncho +737:pool table, billiard table, snooker table +738:pop bottle, soda bottle +739:pot, flowerpot +740:potter's wheel +741:power drill +742:prayer rug, prayer mat +743:printer +744:prison, prison house +745:projectile, missile +746:projector +747:puck, hockey puck +748:punching bag, punch bag, punching ball, punchball +749:purse +750:quill, quill pen +751:quilt, comforter, comfort, puff +752:racer, race car, racing car +753:racket, racquet +754:radiator +755:radio, wireless +756:radio telescope, radio reflector +757:rain barrel +758:recreational vehicle, RV, R.V. +759:reel +760:reflex camera +761:refrigerator, icebox +762:remote control, remote +763:restaurant, eating house, eating place, eatery +764:revolver, six-gun, six-shooter +765:rifle +766:rocking chair, rocker +767:rotisserie +768:rubber eraser, rubber, pencil eraser +769:rugby ball +770:rule, ruler +771:running shoe +772:safe +773:safety pin +774:saltshaker, salt shaker +775:sandal +776:sarong +777:sax, saxophone +778:scabbard +779:scale, weighing machine +780:school bus +781:schooner +782:scoreboard +783:screen, CRT screen +784:screw +785:screwdriver +786:seat belt, seatbelt +787:sewing machine +788:shield, buckler +789:shoe shop, shoe-shop, shoe store +790:shoji +791:shopping basket +792:shopping cart +793:shovel +794:shower cap +795:shower curtain +796:ski +797:ski mask +798:sleeping bag +799:slide rule, slipstick +800:sliding door +801:slot, one-armed bandit +802:snorkel +803:snowmobile +804:snowplow, snowplough +805:soap dispenser +806:soccer ball +807:sock +808:solar dish, solar collector, solar furnace +809:sombrero +810:soup bowl +811:space bar +812:space heater +813:space shuttle +814:spatula +815:speedboat +816:spider web, spider's web +817:spindle +818:sports car, sport car +819:spotlight, spot +820:stage +821:steam locomotive +822:steel arch bridge +823:steel drum +824:stethoscope +825:stole +826:stone wall +827:stopwatch, stop watch +828:stove +829:strainer +830:streetcar, tram, tramcar, trolley, trolley car +831:stretcher +832:studio couch, day bed +833:stupa, tope +834:submarine, pigboat, sub, U-boat +835:suit, suit of clothes +836:sundial +837:sunglass +838:sunglasses, dark glasses, shades +839:sunscreen, sunblock, sun blocker +840:suspension bridge +841:swab, swob, mop +842:sweatshirt +843:swimming trunks, bathing trunks +844:swing +845:switch, electric switch, electrical switch +846:syringe +847:table lamp +848:tank, army tank, armored combat vehicle, armoured combat vehicle +849:tape player +850:teapot +851:teddy, teddy bear +852:television, television system +853:tennis ball +854:thatch, thatched roof +855:theater curtain, theatre curtain +856:thimble +857:thresher, thrasher, threshing machine +858:throne +859:tile roof +860:toaster +861:tobacco shop, tobacconist shop, tobacconist +862:toilet seat +863:torch +864:totem pole +865:tow truck, tow car, wrecker +866:toyshop +867:tractor +868:trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi +869:tray +870:trench coat +871:tricycle, trike, velocipede +872:trimaran +873:tripod +874:triumphal arch +875:trolleybus, trolley coach, trackless trolley +876:trombone +877:tub, vat +878:turnstile +879:typewriter keyboard +880:umbrella +881:unicycle, monocycle +882:upright, upright piano +883:vacuum, vacuum cleaner +884:vase +885:vault +886:velvet +887:vending machine +888:vestment +889:viaduct +890:violin, fiddle +891:volleyball +892:waffle iron +893:wall clock +894:wallet, billfold, notecase, pocketbook +895:wardrobe, closet, press +896:warplane, military plane +897:washbasin, handbasin, washbowl, lavabo, wash-hand basin +898:washer, automatic washer, washing machine +899:water bottle +900:water jug +901:water tower +902:whiskey jug +903:whistle +904:wig +905:window screen +906:window shade +907:Windsor tie +908:wine bottle +909:wing +910:wok +911:wooden spoon +912:wool, woolen, woollen +913:worm fence, snake fence, snake-rail fence, Virginia fence +914:wreck +915:yawl +916:yurt +917:web site, website, internet site, site +918:comic book +919:crossword puzzle, crossword +920:street sign +921:traffic light, traffic signal, stoplight +922:book jacket, dust cover, dust jacket, dust wrapper +923:menu +924:plate +925:guacamole +926:consomme +927:hot pot, hotpot +928:trifle +929:ice cream, icecream +930:ice lolly, lolly, lollipop, popsicle +931:French loaf +932:bagel, beigel +933:pretzel +934:cheeseburger +935:hotdog, hot dog, red hot +936:mashed potato +937:head cabbage +938:broccoli +939:cauliflower +940:zucchini, courgette +941:spaghetti squash +942:acorn squash +943:butternut squash +944:cucumber, cuke +945:artichoke, globe artichoke +946:bell pepper +947:cardoon +948:mushroom +949:Granny Smith +950:strawberry +951:orange +952:lemon +953:fig +954:pineapple, ananas +955:banana +956:jackfruit, jak, jack +957:custard apple +958:pomegranate +959:hay +960:carbonara +961:chocolate sauce, chocolate syrup +962:dough +963:meat loaf, meatloaf +964:pizza, pizza pie +965:potpie +966:burrito +967:red wine +968:espresso +969:cup +970:eggnog +971:alp +972:bubble +973:cliff, drop, drop-off +974:coral reef +975:geyser +976:lakeside, lakeshore +977:promontory, headland, head, foreland +978:sandbar, sand bar +979:seashore, coast, seacoast, sea-coast +980:valley, vale +981:volcano +982:ballplayer, baseball player +983:groom, bridegroom +984:scuba diver +985:rapeseed +986:daisy +987:yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum +988:corn +989:acorn +990:hip, rose hip, rosehip +991:buckeye, horse chestnut, conker +992:coral fungus +993:agaric +994:gyromitra +995:stinkhorn, carrion fungus +996:earthstar +997:hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa +998:bolete +999:ear, spike, capitulum +1000:toilet tissue, toilet paper, bathroom tissue diff --git a/tests/TfMnist-Armnn/TfMnist-Armnn.cpp b/tests/TfMnist-Armnn/TfMnist-Armnn.cpp index bcc3f416cc..e492b9051a 100644 --- a/tests/TfMnist-Armnn/TfMnist-Armnn.cpp +++ b/tests/TfMnist-Armnn/TfMnist-Armnn.cpp @@ -13,11 +13,18 @@ int main(int argc, char* argv[]) int retVal = EXIT_FAILURE; try { + using DataType = float; + using DatabaseType = MnistDatabase; + using ParserType = armnnTfParser::ITfParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "simple_mnist_tf.prototxt", false, "Placeholder", "Softmax", { 0, 1, 2, 3, 4 }, - [](const char* dataDir) { return MnistDatabase(dataDir, true); }, + [](const char* dataDir, const ModelType&) { + return DatabaseType(dataDir, true); + }, &inputTensorShape); } catch (const std::exception& e) diff --git a/tests/TfMobileNet-Armnn/TfMobileNet-Armnn.cpp b/tests/TfMobileNet-Armnn/TfMobileNet-Armnn.cpp index 54759bf88a..cba70c94d3 100644 --- a/tests/TfMobileNet-Armnn/TfMobileNet-Armnn.cpp +++ b/tests/TfMobileNet-Armnn/TfMobileNet-Armnn.cpp @@ -3,7 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../MobileNetDatabase.hpp" +#include "../ImagePreprocessor.hpp" #include "armnnTfParser/ITfParser.hpp" int main(int argc, char* argv[]) @@ -15,7 +15,7 @@ int main(int argc, char* argv[]) std::vector imageSet = { {"Dog.jpg", 209}, - // top five predictions in tensorflow: + // Top five predictions in tensorflow: // ----------------------------------- // 209:Labrador retriever 0.949995 // 160:Rhodesian ridgeback 0.0270182 @@ -23,7 +23,7 @@ int main(int argc, char* argv[]) // 853:tennis ball 0.000470382 // 239:Greater Swiss Mountain dog 0.000464451 {"Cat.jpg", 283}, - // top five predictions in tensorflow: + // Top five predictions in tensorflow: // ----------------------------------- // 283:tiger cat 0.579016 // 286:Egyptian cat 0.319676 @@ -31,7 +31,7 @@ int main(int argc, char* argv[]) // 288:lynx, catamount 0.011163 // 289:leopard, Panthera pardus 0.000856755 {"shark.jpg", 3}, - // top five predictions in tensorflow: + // Top five predictions in tensorflow: // ----------------------------------- // 3:great white shark, white shark, ... 0.996926 // 4:tiger shark, Galeocerdo cuvieri 0.00270528 @@ -42,11 +42,21 @@ int main(int argc, char* argv[]) armnn::TensorShape inputTensorShape({ 1, 224, 224, 3 }); + using DataType = float; + using DatabaseType = ImagePreprocessor; + using ParserType = armnnTfParser::ITfParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( - argc, argv, "mobilenet_v1_1.0_224_fp32.pb", true, "input", "output", { 0, 1, 2 }, - [&imageSet](const char* dataDir) { - return MobileNetDatabase( + retVal = armnn::test::ClassifierInferenceTestMain( + argc, argv, + "mobilenet_v1_1.0_224_fp32.pb", // model name + true, // model is binary + "input", "output", // input and output tensor names + { 0, 1, 2 }, // test images to test with as above + [&imageSet](const char* dataDir, const ModelType&) { + // This creates a 224x224x3 NHWC float tensor to pass to Armnn + return DatabaseType( dataDir, 224, 224, diff --git a/tests/TfResNext_Quantized-Armnn/TfResNext_Quantized-Armnn.cpp b/tests/TfResNext_Quantized-Armnn/TfResNext_Quantized-Armnn.cpp index 1e1ede3e68..5817e8bb46 100644 --- a/tests/TfResNext_Quantized-Armnn/TfResNext_Quantized-Armnn.cpp +++ b/tests/TfResNext_Quantized-Armnn/TfResNext_Quantized-Armnn.cpp @@ -3,7 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "../InferenceTest.hpp" -#include "../ImageNetDatabase.hpp" +#include "../CaffePreprocessor.hpp" #include "armnnTfParser/ITfParser.hpp" int main(int argc, char* argv[]) @@ -20,11 +20,18 @@ int main(int argc, char* argv[]) armnn::TensorShape inputTensorShape({ 1, 3, 224, 224 }); + using DataType = float; + using DatabaseType = CaffePreprocessor; + using ParserType = armnnTfParser::ITfParser; + using ModelType = InferenceModel; + // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions. - retVal = armnn::test::ClassifierInferenceTestMain( + retVal = armnn::test::ClassifierInferenceTestMain( argc, argv, "resnext_TF_quantized_for_armnn_team.pb", true, "inputs", "pool1", { 0, 1 }, - [&imageSet](const char* dataDir) { return ImageNetDatabase(dataDir, 224, 224, imageSet); }, + [&imageSet](const char* dataDir, const ModelType &) { + return DatabaseType(dataDir, 224, 224, imageSet); + }, &inputTensorShape); } catch (const std::exception& e) diff --git a/tests/YoloDatabase.cpp b/tests/YoloDatabase.cpp index 4c91384073..71362b2218 100644 --- a/tests/YoloDatabase.cpp +++ b/tests/YoloDatabase.cpp @@ -78,12 +78,12 @@ std::unique_ptr YoloDatabase::GetTestCaseData(unsig const auto& testCaseInputOutput = g_PerTestCaseInputOutput[testCaseId]; const std::string imagePath = m_ImageDir + testCaseInputOutput.first; - // Load test case input image + // Loads test case input image. std::vector imageData; try { InferenceTestImage image(imagePath.c_str()); - image.Resize(YoloImageWidth, YoloImageHeight); + image.Resize(YoloImageWidth, YoloImageHeight, CHECK_LOCATION()); imageData = GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout::Rgb, image); } catch (const InferenceTestImageException& e) @@ -92,10 +92,10 @@ std::unique_ptr YoloDatabase::GetTestCaseData(unsig return nullptr; } - // Prepare test case output + // Prepares test case output. std::vector topObjectDetections; topObjectDetections.reserve(1); topObjectDetections.push_back(testCaseInputOutput.second); return std::make_unique(std::move(imageData), std::move(topObjectDetections)); -} \ No newline at end of file +} diff --git a/tests/YoloInferenceTest.hpp b/tests/YoloInferenceTest.hpp index edc4808939..c46cc64b73 100644 --- a/tests/YoloInferenceTest.hpp +++ b/tests/YoloInferenceTest.hpp @@ -105,10 +105,10 @@ public: { for (Boost3dArray::index c = 0; c < numClasses; ++c) { - // Resolved confidence: Class probabilities * scales + // Resolved confidence: class probabilities * scales. const float confidence = classProbabilities[y][x][c] * scales[y][x][s]; - // Resolve bounding box and store + // Resolves bounding box and stores. YoloBoundingBox box; box.m_X = boxes[y][x][s][0]; box.m_Y = boxes[y][x][s][1]; @@ -121,16 +121,16 @@ public: } } - // Sort detected objects by confidence + // Sorts detected objects by confidence. std::sort(detectedObjects.begin(), detectedObjects.end(), [](const YoloDetectedObject& a, const YoloDetectedObject& b) { - // Sort by largest confidence first, then by class + // Sorts by largest confidence first, then by class. return a.m_Confidence > b.m_Confidence || (a.m_Confidence == b.m_Confidence && a.m_Class > b.m_Class); }); - // Check the top N detections + // Checks the top N detections. auto outputIt = detectedObjects.begin(); auto outputEnd = detectedObjects.end(); @@ -138,7 +138,7 @@ public: { if (outputIt == outputEnd) { - // Somehow expected more things to check than detections found by the model + // Somehow expected more things to check than detections found by the model. return TestCaseResult::Abort; } -- cgit v1.2.1