aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2018-10-26 10:24:14 +0100
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>2018-10-26 12:38:34 +0000
commitd8eee59735526ead6b87343c3ed9069e682b6e8c (patch)
tree44b994c98857b28674fb4911cb20489aaaecc437
parentd4dfa684941a21314b70593d01b0fc2167eebad4 (diff)
downloadarmnn-d8eee59735526ead6b87343c3ed9069e682b6e8c.tar.gz
IVGCVSW-2029 Fix fully connected layer support in TfLite Parser and implement test for TfLite VGG16 quantized
Change-Id: I2061f62f62684b963fa0f090718f1dcffe5c93ce
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp1
-rw-r--r--src/armnnTfLiteParser/test/FullyConnected.cpp4
-rw-r--r--tests/CMakeLists.txt6
-rw-r--r--tests/ImagePreprocessor.cpp8
-rw-r--r--tests/ImagePreprocessor.hpp5
-rw-r--r--tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp68
6 files changed, 88 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 8b1d3e6bc4..5e0d4b7d6a 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1231,6 +1231,7 @@ void TfLiteParser::ParseFullyConnected(size_t subgraphIndex, size_t operatorInde
FullyConnectedDescriptor desc;
desc.m_BiasEnabled = false;
+ desc.m_TransposeWeightMatrix = true;
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index 2853fe96ab..14ca57c2ab 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -118,7 +118,7 @@ struct FullyConnectedWithNoBiasFixture : FullyConnectedFixture
FullyConnectedWithNoBiasFixture()
: FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
"[ 1, 1 ]", // outputShape
- "[ 4, 1 ]", // filterShape
+ "[ 1, 4 ]", // filterShape
"[ 2, 3, 4, 5 ]") // filterData
{}
};
@@ -136,7 +136,7 @@ struct FullyConnectedWithBiasFixture : FullyConnectedFixture
FullyConnectedWithBiasFixture()
: FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
"[ 1, 1 ]", // outputShape
- "[ 4, 1 ]", // filterShape
+ "[ 1, 4 ]", // filterShape
"[ 2, 3, 4, 5 ]", // filterData
"[ 1 ]", // biasShape
"[ 10, 0, 0, 0 ]" ) // biasData
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 97f21154fe..d6475c263b 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -163,6 +163,12 @@ if (BUILD_TF_LITE_PARSER)
ImagePreprocessor.hpp
ImagePreprocessor.cpp)
TfLiteParserTest(TfLiteMobilenetQuantized-Armnn "${TfLiteMobilenetQuantized-Armnn_sources}")
+
+ set(TfLiteVGG16Quantized-Armnn_sources
+ TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
+ ImagePreprocessor.hpp
+ ImagePreprocessor.cpp)
+ TfLiteParserTest(TfLiteVGG16Quantized-Armnn "${TfLiteVGG16Quantized-Armnn_sources}")
endif()
if (BUILD_ONNX_PARSER)
diff --git a/tests/ImagePreprocessor.cpp b/tests/ImagePreprocessor.cpp
index 1f29cffe65..8ceedd2d04 100644
--- a/tests/ImagePreprocessor.cpp
+++ b/tests/ImagePreprocessor.cpp
@@ -33,10 +33,16 @@ unsigned int ImagePreprocessor<TDataType>::GetLabelAndResizedImageAsFloat(unsign
InferenceTestImage::ResizingMethods::BilinearAndNormalized,
m_Mean, m_Stddev);
+ // duplicate data across the batch
+ for (unsigned int i = 1; i < m_BatchSize; i++)
+ {
+ result.insert( result.end(), result.begin(), result.begin() + GetNumImageElements() );
+ }
+
if (m_DataFormat == DataFormat::NCHW)
{
const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
- armnn::TensorShape dstShape({1, 3, m_Height, m_Width});
+ armnn::TensorShape dstShape({m_BatchSize, 3, m_Height, m_Width});
std::vector<float> tempImage(result.size());
armnnUtils::Permute<float>(dstShape, NHWCToArmNN, result.data(), tempImage.data());
result.swap(tempImage);
diff --git a/tests/ImagePreprocessor.hpp b/tests/ImagePreprocessor.hpp
index 9add6d86ad..d77113c6d9 100644
--- a/tests/ImagePreprocessor.hpp
+++ b/tests/ImagePreprocessor.hpp
@@ -37,10 +37,12 @@ public:
int32_t offset=0,
const std::array<float, 3> mean={{0, 0, 0}},
const std::array<float, 3> stddev={{1, 1, 1}},
- DataFormat dataFormat=DataFormat::NHWC)
+ DataFormat dataFormat=DataFormat::NHWC,
+ unsigned int batchSize=1)
: m_BinaryDirectory(binaryFileDirectory)
, m_Height(height)
, m_Width(width)
+ , m_BatchSize(batchSize)
, m_Scale(scale)
, m_Offset(offset)
, m_ImageSet(imageSet)
@@ -61,6 +63,7 @@ private:
std::string m_BinaryDirectory;
unsigned int m_Height;
unsigned int m_Width;
+ unsigned int m_BatchSize;
// Quantization parameters
float m_Scale;
int32_t m_Offset;
diff --git a/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp b/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
new file mode 100644
index 0000000000..1313d2d01a
--- /dev/null
+++ b/tests/TfLiteVGG16Quantized-Armnn/TfLiteVGG16Quantized-Armnn.cpp
@@ -0,0 +1,68 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "../InferenceTest.hpp"
+#include "../ImagePreprocessor.hpp"
+#include "armnnTfLiteParser/ITfLiteParser.hpp"
+
+using namespace armnnTfLiteParser;
+
+int main(int argc, char* argv[])
+{
+ int retVal = EXIT_FAILURE;
+ try
+ {
+ // Coverity fix: The following code may throw an exception of type std::length_error.
+ std::vector<ImageSet> imageSet =
+ {
+ // Class number in probability print out offset by 1000 due to batch size fix
+ {"Dog.jpg", 669},
+ {"Cat.jpg", 669},
+ {"shark.jpg", 669},
+ };
+
+ armnn::TensorShape inputTensorShape({ 2, 224, 224, 3 });
+
+ using DataType = uint8_t;
+ using DatabaseType = ImagePreprocessor<DataType>;
+ using ParserType = armnnTfLiteParser::ITfLiteParser;
+ using ModelType = InferenceModel<ParserType, DataType>;
+
+ // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions.
+ retVal = armnn::test::ClassifierInferenceTestMain<DatabaseType,
+ ParserType>(
+ argc, argv,
+ "vgg_16_u8.tflite", // model name
+ true, // model is binary
+ "content_vgg/concat", // input tensor name
+ "content_vgg/prob", // output tensor name
+ { 0, 1, 2 }, // test images to test with as above
+ [&imageSet](const char* dataDir, const ModelType & model) {
+ // we need to get the input quantization parameters from
+ // the parsed model
+ auto inputBinding = model.GetInputBindingInfo();
+ return DatabaseType(
+ dataDir,
+ 224,
+ 224,
+ imageSet,
+ inputBinding.second.GetQuantizationScale(),
+ inputBinding.second.GetQuantizationOffset(),
+ {{0, 0, 0}},
+ {{1, 1, 1}},
+ DatabaseType::DataFormat::NCHW,
+ 2);
+ },
+ &inputTensorShape);
+ }
+ catch (const std::exception& e)
+ {
+ // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
+ // exception of type std::length_error.
+ // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
+ std::cerr << "WARNING: " << *argv << ": An error has occurred when running "
+ "the classifier inference tests: " << e.what() << std::endl;
+ }
+ return retVal;
+}