5 #include "../InferenceTest.hpp" 6 #include "../ImagePreprocessor.hpp" 9 #include <cxxopts/cxxopts.hpp> 16 std::ifstream read(filename);
17 std::vector<ImageSet> imageSet;
21 for (std::string line; std::getline(read, line);)
23 stringstream ss(line);
24 std::string image_name;
26 getline(ss, image_name,
' ');
27 getline(ss, label,
' ');
28 imageSet.push_back(
ImageSet(image_name, std::stoi(label)));
34 imageSet.push_back(
ImageSet(
"Dog.jpg", 209));
42 imageSet.push_back(
ImageSet(
"Cat.jpg", 283));
50 imageSet.push_back(
ImageSet(
"shark.jpg", 3));
64 cxxopts::Options options(
"TfLiteMobilenetQuantized-Armnn",
"Validation Options");
70 .allow_unrecognised_options()
73 "Filename of a text file where in each line contains an image " 74 "filename and the correct label the network should predict when fed that image",
75 cxxopts::value<std::string>(fileName));
77 auto result = options.parse(argc, argv);
79 catch (
const cxxopts::OptionException& e)
81 std::cerr << e.what() << std::endl;
84 catch (
const std::exception& e)
86 std::cerr <<
"Fatal internal error: [" << e.what() <<
"]" << std::endl;
94 int main(
int argc,
char* argv[])
96 int retVal = EXIT_FAILURE;
101 std::vector<ImageSet> imageSet =
ParseDataset(labels_file);
102 std::vector<unsigned int> indices(imageSet.size());
103 std::generate(indices.begin(), indices.end(), [n = 0] ()
mutable {
return n++; });
116 "mobilenet_v1_1.0_224_quant.tflite",
119 "MobilenetV1/Predictions/Reshape_1",
121 [&imageSet](
const char* dataDir,
const ModelType &) {
133 catch (
const std::exception& e)
138 std::cerr <<
"WARNING: " << *argv <<
": An error has occurred when running " 139 "the classifier inference tests: " << e.what() << std::endl;
std::vector< ImageSet > ParseDataset(const std::string &filename)
std::string GetLabelsFilenameFromOptions(int argc, char *argv[])
int main(int argc, char *argv[])
std::pair< const std::string, unsigned int > ImageSet
Tf requires RGB images, normalized in range [0, 1] and resized using Bilinear algorithm.
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)