diff options
Diffstat (limited to 'tests/model_objects/AlexNet.h')
-rw-r--r-- | tests/model_objects/AlexNet.h | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tests/model_objects/AlexNet.h b/tests/model_objects/AlexNet.h index c9fd448d5d..45622e2118 100644 --- a/tests/model_objects/AlexNet.h +++ b/tests/model_objects/AlexNet.h @@ -24,6 +24,8 @@ #ifndef __ARM_COMPUTE_TEST_MODEL_OBJECTS_ALEXNET_H__ #define __ARM_COMPUTE_TEST_MODEL_OBJECTS_ALEXNET_H__ +#include "arm_compute/runtime/Tensor.h" + #include "tests/AssetsLibrary.h" #include "tests/Globals.h" #include "tests/Utils.h" @@ -149,7 +151,7 @@ public: b[6]->allocator()->init(TensorInfo(TensorShape(4096U), 1, dt, fixed_point_position)); b[7]->allocator()->init(TensorInfo(TensorShape(1000U), 1, dt, fixed_point_position)); - if(_batches > 1) + if(_batches > 1 && std::is_same<TensorType, Tensor>::value) { w[5]->allocator()->init(TensorInfo(TensorShape(9216U * dt_size, 4096U / dt_size), 1, dt, fixed_point_position)); w[6]->allocator()->init(TensorInfo(TensorShape(4096U * dt_size, 4096U / dt_size), 1, dt, fixed_point_position)); |