aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/LoadModel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/LoadModel.cpp')
-rw-r--r--src/armnnTfLiteParser/test/LoadModel.cpp76
1 files changed, 38 insertions, 38 deletions
diff --git a/src/armnnTfLiteParser/test/LoadModel.cpp b/src/armnnTfLiteParser/test/LoadModel.cpp
index 1afb5f12e5..e09de68c72 100644
--- a/src/armnnTfLiteParser/test/LoadModel.cpp
+++ b/src/armnnTfLiteParser/test/LoadModel.cpp
@@ -2,7 +2,7 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include <boost/test/unit_test.hpp>
+
#include "ParserFlatbuffersFixture.hpp"
#include "../TfLiteParser.hpp"
@@ -13,8 +13,8 @@ using ModelPtr = TfLiteParserImpl::ModelPtr;
using SubgraphPtr = TfLiteParserImpl::SubgraphPtr;
using OperatorPtr = TfLiteParserImpl::OperatorPtr;
-BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
-
+TEST_SUITE("TensorflowLiteParser_LoadModel")
+{
struct LoadModelFixture : public ParserFlatbuffersFixture
{
explicit LoadModelFixture()
@@ -137,53 +137,53 @@ struct LoadModelFixture : public ParserFlatbuffersFixture
const std::vector<tflite::BuiltinOperator>& opcodes,
size_t subgraphs, const std::string desc, size_t buffers)
{
- BOOST_CHECK(model);
- BOOST_CHECK_EQUAL(version, model->version);
- BOOST_CHECK_EQUAL(opcodeSize, model->operator_codes.size());
+ CHECK(model);
+ CHECK_EQ(version, model->version);
+ CHECK_EQ(opcodeSize, model->operator_codes.size());
CheckBuiltinOperators(opcodes, model->operator_codes);
- BOOST_CHECK_EQUAL(subgraphs, model->subgraphs.size());
- BOOST_CHECK_EQUAL(desc, model->description);
- BOOST_CHECK_EQUAL(buffers, model->buffers.size());
+ CHECK_EQ(subgraphs, model->subgraphs.size());
+ CHECK_EQ(desc, model->description);
+ CHECK_EQ(buffers, model->buffers.size());
}
void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
{
- BOOST_CHECK_EQUAL(expectedOperators.size(), result.size());
+ CHECK_EQ(expectedOperators.size(), result.size());
for (size_t i = 0; i < expectedOperators.size(); i++)
{
- BOOST_CHECK_EQUAL(expectedOperators[i], result[i]->builtin_code);
+ CHECK_EQ(expectedOperators[i], result[i]->builtin_code);
}
}
void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
{
- BOOST_CHECK(subgraph);
- BOOST_CHECK_EQUAL(tensors, subgraph->tensors.size());
- BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end());
- BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
- subgraph->outputs.begin(), subgraph->outputs.end());
- BOOST_CHECK_EQUAL(operators, subgraph->operators.size());
- BOOST_CHECK_EQUAL(name, subgraph->name);
+ CHECK(subgraph);
+ CHECK_EQ(tensors, subgraph->tensors.size());
+ CHECK(std::equal(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end()));
+ CHECK(std::equal(outputs.begin(), outputs.end(),
+ subgraph->outputs.begin(), subgraph->outputs.end()));
+ CHECK_EQ(operators, subgraph->operators.size());
+ CHECK_EQ(name, subgraph->name);
}
void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode, const std::vector<int32_t>& inputs,
const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
tflite::CustomOptionsFormat custom_options_format)
{
- BOOST_CHECK(operatorPtr);
- BOOST_CHECK_EQUAL(opcode, operatorPtr->opcode_index);
- BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(),
- operatorPtr->inputs.begin(), operatorPtr->inputs.end());
- BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
- operatorPtr->outputs.begin(), operatorPtr->outputs.end());
- BOOST_CHECK_EQUAL(optionType, operatorPtr->builtin_options.type);
- BOOST_CHECK_EQUAL(custom_options_format, operatorPtr->custom_options_format);
+ CHECK(operatorPtr);
+ CHECK_EQ(opcode, operatorPtr->opcode_index);
+ CHECK(std::equal(inputs.begin(), inputs.end(),
+ operatorPtr->inputs.begin(), operatorPtr->inputs.end()));
+ CHECK(std::equal(outputs.begin(), outputs.end(),
+ operatorPtr->outputs.begin(), operatorPtr->outputs.end()));
+ CHECK_EQ(optionType, operatorPtr->builtin_options.type);
+ CHECK_EQ(custom_options_format, operatorPtr->custom_options_format);
}
};
-BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
+TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromBinary")
{
TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
m_GraphBinary.size());
@@ -197,14 +197,14 @@ BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
-BOOST_FIXTURE_TEST_CASE(LoadModelFromFile, LoadModelFixture)
+TEST_CASE_FIXTURE(LoadModelFixture, "LoadModelFromFile")
{
using namespace fs;
fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv");
bool saved = flatbuffers::SaveFile(fname.c_str(),
reinterpret_cast<char *>(m_GraphBinary.data()),
m_GraphBinary.size(), true);
- BOOST_CHECK_MESSAGE(saved, "Cannot save test file");
+ CHECK_MESSAGE(saved, "Cannot save test file");
TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromFile(fname.c_str());
CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
@@ -218,26 +218,26 @@ BOOST_FIXTURE_TEST_CASE(LoadModelFromFile, LoadModelFixture)
remove(fname);
}
-BOOST_AUTO_TEST_CASE(LoadNullBinary)
+TEST_CASE("LoadNullBinary")
{
- BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
+ CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
}
-BOOST_AUTO_TEST_CASE(LoadInvalidBinary)
+TEST_CASE("LoadInvalidBinary")
{
std::string testData = "invalid data";
- BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
+ CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
testData.length()), armnn::ParseException);
}
-BOOST_AUTO_TEST_CASE(LoadFileNotFound)
+TEST_CASE("LoadFileNotFound")
{
- BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
+ CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
}
-BOOST_AUTO_TEST_CASE(LoadNullPtrFile)
+TEST_CASE("LoadNullPtrFile")
{
- BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
+ CHECK_THROWS_AS(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
}
-BOOST_AUTO_TEST_SUITE_END()
+}