aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefCreateWorkloadTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 4293ef54f3..fae8d0cdd4 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -486,6 +486,24 @@ TEST_CASE("RefCreateFullyConnectedWithBlobWorkloadTest")
TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
}
+TEST_CASE("CreateFullyConnectedWorkloadWeightsBiasesAsInputsFloat32")
+{
+ Graph graph;
+ RefWorkloadFactory factory = GetFactory();
+
+ auto workload =
+ CreateFullyConnectedWorkloadWeightsBiasesAsInputsTest<RefFullyConnectedWorkload,
+ armnn::DataType::Float32>(factory, graph);
+
+ // Checks that outputs and inputs are as we expect them (see definition of CreateFullyConnectedWorkloadTest).
+ float inputsQScale = 0.0f;
+ float outputQScale = 0.0f;
+ CheckInputsOutput(std::move(workload),
+ TensorInfo({ 3, 1, 4, 5 }, armnn::DataType::Float32, inputsQScale),
+ TensorInfo({ 7, 20 }, armnn::DataType::Float32, inputsQScale),
+ TensorInfo({ 3, 7 }, armnn::DataType::Float32, outputQScale));
+}
+
template <typename FullyConnectedWorkloadType, armnn::DataType DataType>
static void RefCreateFullyConnectedWorkloadTest()
{