diff options
Diffstat (limited to 'delegate/src/test/SoftmaxTestHelper.hpp')
-rw-r--r-- | delegate/src/test/SoftmaxTestHelper.hpp | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/delegate/src/test/SoftmaxTestHelper.hpp b/delegate/src/test/SoftmaxTestHelper.hpp index b3086bb0cb..bd32c212e9 100644 --- a/delegate/src/test/SoftmaxTestHelper.hpp +++ b/delegate/src/test/SoftmaxTestHelper.hpp @@ -167,4 +167,26 @@ void SoftmaxTest(tflite::BuiltinOperator softmaxOperatorCode, } } + +/// Convenience function to run softmax and log-softmax test cases +/// \param operatorCode tflite::BuiltinOperator_SOFTMAX or tflite::BuiltinOperator_LOG_SOFTMAX +/// \param backends armnn backends to target +/// \param beta multiplicative parameter to the softmax function +/// \param expectedOutput to be checked against transformed input +void SoftmaxTestCase(tflite::BuiltinOperator operatorCode, + std::vector<armnn::BackendId> backends, float beta, std::vector<float> expectedOutput) { + std::vector<float> input = { + 1.0, 2.5, 3.0, 4.5, 5.0, + -1.0, -2.5, -3.0, -4.5, -5.0}; + std::vector<int32_t> shape = {2, 5}; + + SoftmaxTest(operatorCode, + tflite::TensorType_FLOAT32, + backends, + shape, + input, + expectedOutput, + beta); +} + } // anonymous namespace |