about summary refs log tree commit diff
path: root/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp')
-rw-r--r-- tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp | 14
1 file changed, 7 insertions, 7 deletions
diff --git a/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp b/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
index e23dbdc9d4..84d5292195 100644
--- a/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
+++ b/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
@@ -23,7 +23,7 @@ int main(int argc, char* argv[])
{"shark.jpg", 669},
};
- armnn::TensorShape inputTensorShape({ 2, 224, 224, 3 });
+ armnn::TensorShape inputTensorShape({ 1, 224, 224, 3 });
using DataType = uint8_t;
using DatabaseType = ImagePreprocessor<DataType>;
@@ -34,11 +34,11 @@ int main(int argc, char* argv[])
retVal = armnn::test::ClassifierInferenceTestMain<DatabaseType,
ParserType>(
argc, argv,
- "vgg_16_u8.tflite", // model name
- true, // model is binary
- "content_vgg/concat", // input tensor name
- "content_vgg/prob", // output tensor name
- { 0, 1, 2 }, // test images to test with as above
+ "vgg_16_u8_batch1.tflite", // model name
+ true, // model is binary
+ "content_vgg/concat", // input tensor name
+ "content_vgg/prob", // output tensor name
+ { 0, 1, 2 }, // test images to test with as above
[&imageSet](const char* dataDir, const ModelType & model) {
// we need to get the input quantization parameters from
// the parsed model
@@ -53,7 +53,7 @@ int main(int argc, char* argv[])
{{0, 0, 0}},
{{1, 1, 1}},
DatabaseType::DataFormat::NCHW,
- 2);
+ 1);
},
&inputTensorShape);
}