aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/test/SoftmaxTestHelper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/test/SoftmaxTestHelper.hpp')
-rw-r--r--delegate/src/test/SoftmaxTestHelper.hpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/delegate/src/test/SoftmaxTestHelper.hpp b/delegate/src/test/SoftmaxTestHelper.hpp
index b3086bb0cb..bd32c212e9 100644
--- a/delegate/src/test/SoftmaxTestHelper.hpp
+++ b/delegate/src/test/SoftmaxTestHelper.hpp
@@ -167,4 +167,26 @@ void SoftmaxTest(tflite::BuiltinOperator softmaxOperatorCode,
}
}
+
+/// Convenience function to run softmax and log-softmax test cases
+/// \param operatorCode tflite::BuiltinOperator_SOFTMAX or tflite::BuiltinOperator_LOG_SOFTMAX
+/// \param backends armnn backends to target
+/// \param beta multiplicative parameter to the softmax function
+/// \param expectedOutput to be checked against transformed input
+void SoftmaxTestCase(tflite::BuiltinOperator operatorCode,
+ std::vector<armnn::BackendId> backends, float beta, std::vector<float> expectedOutput) {
+ std::vector<float> input = {
+ 1.0, 2.5, 3.0, 4.5, 5.0,
+ -1.0, -2.5, -3.0, -4.5, -5.0};
+ std::vector<int32_t> shape = {2, 5};
+
+ SoftmaxTest(operatorCode,
+ tflite::TensorType_FLOAT32,
+ backends,
+ shape,
+ input,
+ expectedOutput,
+ beta);
+}
+
} // anonymous namespace