aboutsummaryrefslogtreecommitdiff
path: root/tests/model_objects
diff options
context:
space:
mode:
Diffstat (limited to 'tests/model_objects')
-rw-r--r--tests/model_objects/AlexNet.h4
1 files changed, 3 insertions, 1 deletions
diff --git a/tests/model_objects/AlexNet.h b/tests/model_objects/AlexNet.h
index c9fd448d5d..45622e2118 100644
--- a/tests/model_objects/AlexNet.h
+++ b/tests/model_objects/AlexNet.h
@@ -24,6 +24,8 @@
#ifndef __ARM_COMPUTE_TEST_MODEL_OBJECTS_ALEXNET_H__
#define __ARM_COMPUTE_TEST_MODEL_OBJECTS_ALEXNET_H__
+#include "arm_compute/runtime/Tensor.h"
+
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
#include "tests/Utils.h"
@@ -149,7 +151,7 @@ public:
b[6]->allocator()->init(TensorInfo(TensorShape(4096U), 1, dt, fixed_point_position));
b[7]->allocator()->init(TensorInfo(TensorShape(1000U), 1, dt, fixed_point_position));
- if(_batches > 1)
+ if(_batches > 1 && std::is_same<TensorType, Tensor>::value)
{
w[5]->allocator()->init(TensorInfo(TensorShape(9216U * dt_size, 4096U / dt_size), 1, dt, fixed_point_position));
w[6]->allocator()->init(TensorInfo(TensorShape(4096U * dt_size, 4096U / dt_size), 1, dt, fixed_point_position));