10 #include <boost/assert.hpp> 11 #include <boost/core/ignore_unused.hpp> 12 #include <boost/numeric/conversion/cast.hpp> 13 #include <boost/test/tools/floating_point_comparison.hpp> 20 template<
typename Model>
24 DeepSpeechV1TestCase(
Model& model,
25 unsigned int testCaseId,
26 const DeepSpeechV1TestCaseData& testCaseData)
29 { testCaseData.m_InputData.m_InputSeq,
30 testCaseData.m_InputData.m_StateH,
31 testCaseData.m_InputData.m_StateC},
32 { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
33 , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
34 , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
35 testCaseData.m_ExpectedOutputData.m_StateC})
40 boost::ignore_unused(options);
41 const std::vector<float>& output1 = boost::get<std::vector<float>>(this->
GetOutputs()[0]);
42 BOOST_ASSERT(output1.size() == k_OutputSize1);
44 const std::vector<float>& output2 = boost::get<std::vector<float>>(this->
GetOutputs()[1]);
45 BOOST_ASSERT(output2.size() == k_OutputSize2);
47 const std::vector<float>& output3 = boost::get<std::vector<float>>(this->
GetOutputs()[2]);
48 BOOST_ASSERT(output3.size() == k_OutputSize3);
51 for (
unsigned int j = 0u; j < output1.size(); j++)
53 if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
56 " is incorrect at" << j;
61 for (
unsigned int j = 0u; j < output2.size(); j++)
63 if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j]))
71 for (
unsigned int j = 0u; j < output3.size(); j++)
73 if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j]))
85 static constexpr
unsigned int k_OutputSize1 = 464u;
86 static constexpr
unsigned int k_OutputSize2 = 2048u;
87 static constexpr
unsigned int k_OutputSize3 = 2048u;
89 boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
90 LstmInput m_ExpectedOutputs;
93 template <
typename Model>
97 template <
typename TConstructModelCallable>
98 explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
99 : m_ConstructModel(constructModel)
102 virtual void AddCommandLineOptions(boost::program_options::options_description& options)
override 104 namespace po = boost::program_options;
106 options.add_options()
107 (
"input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
108 "Path to directory containing test data for m_InputSeq");
109 options.add_options()
110 (
"prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
111 "Path to directory containing test data for m_PrevStateH");
112 options.add_options()
113 (
"prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
114 "Path to directory containing test data for m_PrevStateC");
115 options.add_options()
116 (
"logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
117 "Path to directory containing test data for m_Logits");
118 options.add_options()
119 (
"new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
120 "Path to directory containing test data for m_NewStateH");
121 options.add_options()
122 (
"new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
123 "Path to directory containing test data for m_NewStateC");
126 Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
161 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
166 m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
167 m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
168 m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
177 std::unique_ptr<IInferenceTestCase> GetTestCase(
unsigned int testCaseId)
override 179 std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
185 return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
189 typename Model::CommandLineOptions m_ModelCommandLineOptions;
191 typename Model::CommandLineOptions)> m_ConstructModel;
192 std::unique_ptr<Model> m_Model;
194 std::string m_InputSeqDir;
195 std::string m_PrevStateCDir;
196 std::string m_PrevStateHDir;
197 std::string m_LogitsDir;
198 std::string m_NewStateCDir;
199 std::string m_NewStateHDir;
201 std::unique_ptr<DeepSpeechV1Database> m_Database;
#define ARMNN_LOG(severity)
The test completed without any errors.
const std::vector< TContainer > & GetOutputs() const
virtual TestCaseResult ProcessResult(const InferenceTestOptions &options)=0
unsigned int GetTestCaseId() const
armnn::Runtime::CreationOptions::ExternalProfilingOptions options
bool ValidateDirectory(std::string &dir)