aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/WeightsReshapeFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/WeightsReshapeFixture.h')
-rw-r--r--tests/validation/fixtures/WeightsReshapeFixture.h16
1 files changed, 13 insertions, 3 deletions
diff --git a/tests/validation/fixtures/WeightsReshapeFixture.h b/tests/validation/fixtures/WeightsReshapeFixture.h
index 0b3e76d677..7c7214acac 100644
--- a/tests/validation/fixtures/WeightsReshapeFixture.h
+++ b/tests/validation/fixtures/WeightsReshapeFixture.h
@@ -45,7 +45,7 @@ namespace validation
using namespace arm_compute::misc::shape_calculator;
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class WeightsReshapeValidationFixture : public framework::Fixture
+class WeightsReshapeOpValidationFixture : public framework::Fixture
{
public:
template <typename...>
@@ -73,7 +73,7 @@ protected:
// Create and configure function
FunctionType weights_reshape_func;
- weights_reshape_func.configure(&src, (has_bias ? &bias : nullptr), &dst, num_groups);
+ weights_reshape_func.configure(src.info(), (has_bias ? bias.info() : nullptr), dst.info(), num_groups);
ARM_COMPUTE_ASSERT(src.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
@@ -99,8 +99,18 @@ protected:
fill(AccessorType(bias), 1);
}
+ arm_compute::ITensorPack pack =
+ {
+ { arm_compute::TensorType::ACL_SRC, &src },
+ { arm_compute::TensorType::ACL_DST, &dst }
+ };
+
+ if(has_bias)
+ {
+ pack.add_const_tensor(arm_compute::TensorType::ACL_BIAS, &bias);
+ }
// Compute function
- weights_reshape_func.run();
+ weights_reshape_func.run(pack);
return dst;
}