aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/SoftmaxLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/SoftmaxLayerFixture.h')
-rw-r--r--tests/validation/fixtures/SoftmaxLayerFixture.h15
1 files changed, 4 insertions, 11 deletions
diff --git a/tests/validation/fixtures/SoftmaxLayerFixture.h b/tests/validation/fixtures/SoftmaxLayerFixture.h
index 29a3ed2cd0..30356d648d 100644
--- a/tests/validation/fixtures/SoftmaxLayerFixture.h
+++ b/tests/validation/fixtures/SoftmaxLayerFixture.h
@@ -32,7 +32,6 @@
#include "tests/IAccessor.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
-#include "tests/validation/reference/LogSoftmaxLayer.h"
#include "tests/validation/reference/SoftmaxLayer.h"
#include <random>
@@ -52,8 +51,8 @@ public:
{
_quantization_info = quantization_info;
- _target = compute_target(shape, data_type, quantization_info, beta, axis);
_reference = compute_reference(shape, data_type, quantization_info, beta, axis);
+ _target = compute_target(shape, data_type, quantization_info, beta, axis);
}
protected:
@@ -62,7 +61,7 @@ protected:
{
if(!is_data_type_quantized(tensor.data_type()))
{
- std::uniform_real_distribution<> distribution(-1000.f, 1000.f);
+ std::uniform_real_distribution<> distribution(-10.f, 10.f);
library->fill(tensor, distribution, 0);
}
else // data type is quantized_asymmetric (signed or unsigned)
@@ -111,14 +110,7 @@ protected:
// Fill reference
fill(src);
- if(IS_LOG)
- {
- return reference::log_softmax_layer<T>(src, beta, axis);
- }
- else
- {
- return reference::softmax_layer<T>(src, beta, axis);
- }
+ return reference::softmax_layer<T>(src, beta, axis, IS_LOG);
}
TensorType _target{};
@@ -155,6 +147,7 @@ public:
axis);
}
};
+
} // namespace validation
} // namespace test
} // namespace arm_compute