author     SiCongLi <sicong.li@arm.com>    2021-10-29 15:05:49 +0100
committer  SiCong Li <sicong.li@arm.com>   2021-11-01 14:29:51 +0000
commit     eb8bd81a625f0f87080dbde55b434362ad57324a (patch)
tree       fda1de0843be17266388d0d137908f392a7f694e
parent     1af5416917268692fcd4b34b1d7ffebd3a2aea8a (diff)
download   ComputeLibrary-eb8bd81a625f0f87080dbde55b434362ad57324a.tar.gz
Fix dst "widening" validation
* Auto-initialize the dst tensor before checking for PostOp shape compliance
  so that we catch the invalid case of "widening" the dst tensor shape
* Rework post op validate test cases to be more readable

Partially resolves: COMPMID-4435

Change-Id: I79943994182942f962e4d59a7fa0d6f017ae9ac7
Signed-off-by: SiCongLi <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6548
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
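For context, the rule this commit tightens can be sketched independently of the library: a fused post op argument may broadcast along a dst dimension, but broadcasting must never enlarge dst. The check is only meaningful once dst has its final shape, which is why configure() now auto-initializes dst before running validate_arguments(). The following minimal, self-contained C++ sketch (illustrative names, not the library's TensorShape/ITensorInfo API) shows the shape rule being enforced:

#include <array>
#include <cassert>
#include <cstddef>

using Shape = std::array<std::size_t, 3>;

// A post op argument must not "widen" dst: after broadcasting against dst,
// the result must still be exactly dst's shape, i.e. every dimension of the
// argument is either 1 (broadcast) or equal to the corresponding dst dimension.
bool widens_dst(const Shape &dst, const Shape &arg)
{
    for(std::size_t d = 0; d < dst.size(); ++d)
    {
        if(arg[d] != 1 && arg[d] != dst[d])
        {
            return true; // broadcasting would enlarge dst in dimension d
        }
    }
    return false;
}

int main()
{
    const Shape dst{ 14, 1, 34 };                 // example shape from the IPostOp.h docs
    assert(!widens_dst(dst, Shape{ 1, 1, 34 }));  // broadcast in dim 0: allowed
    assert(!widens_dst(dst, Shape{ 14, 1, 34 })); // no broadcast: allowed
    assert(widens_dst(dst, Shape{ 14, 15, 34 })); // dim 1 would grow from 1 to 15: rejected
    return 0;
}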
-rw-r--r--  arm_compute/core/experimental/IPostOp.h                    |  18
-rw-r--r--  src/core/CL/CLUtils.cpp                                    |  10
-rw-r--r--  src/core/experimental/PostOp.h                             |  10
-rw-r--r--  src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp  |   4
-rw-r--r--  tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp         | 272
-rw-r--r--  utils/TypePrinter.h                                        |   2
6 files changed, 176 insertions, 140 deletions
diff --git a/arm_compute/core/experimental/IPostOp.h b/arm_compute/core/experimental/IPostOp.h
index cd6b8fc4cc..4fac4c88e9 100644
--- a/arm_compute/core/experimental/IPostOp.h
+++ b/arm_compute/core/experimental/IPostOp.h
@@ -44,7 +44,7 @@ using PostOpTypeSequence = std::vector<PostOpType>;
* It contains:
* 1. The attributes of the original operator.
* 2. Any additional tensor argument.
- * 3. The postion of the previous op's dst tensor in its argument list ( @ref prev_dst_pos )
+ * 3. The position of the previous op's dst tensor in its argument list ( @ref prev_dst_pos )
*
* For example, a series of chained ops:
*
@@ -62,8 +62,16 @@ using PostOpTypeSequence = std::vector<PostOpType>;
* post op1: relu(act_info, prev_dst_pos = 0)
* post op2: div(div_info, src1, prev_dst_pos = 1)
*
- * NOTE: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type
- * NOTE: If TensorRelatedT points to a resource, IPostOp assumes that resource is valid throughout its lifetime
+ * @note: On Broadcasting
+ * For n-ary post ops, the tensor arguments must not "widen" the dst tensor of the main op
+ * For example, for a dst of shape [14, 1, 34]:
+ * * post_op_arg1 = [1, 1, 34] is allowed: broadcast in dim 0
+ * * post_op_arg1 = [14, 1, 34] is allowed: no broadcast
+ * * post_op_arg1 = [1, 1, 1] is allowed: broadcast in dims 0 and 2
+ * * post_op_arg1 = [14, 15, 34] is NOT allowed: broadcast widens the dst tensor
+ *
+ * @note: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type
+ * @note: If TensorRelatedT points to a resource, IPostOp assumes that resource is valid throughout its lifetime
* and the lifetime of its copies. This is almost guaranteed as IPostOp is only meant to be used at configure time
* after the ITensor or ITensorInfo objects are already constructed
*/
@@ -71,7 +79,7 @@ template <typename TensorRelatedT>
struct IPostOp
{
/** Get the arity of the post op
- * NOTE: that this is one fewer than the arity of the original op, because we implicitly pass the previous op's dst
+ * @note: This is one fewer than the arity of the original op, because we implicitly pass the previous op's dst
* tensor as one of the arguments
*/
size_t arity() const
@@ -88,7 +96,7 @@ struct IPostOp
virtual std::vector<TensorRelatedT *> arguments() = 0;
virtual std::vector<const TensorRelatedT *> arguments() const = 0;
/** Clone method used in cases where PostOps are owned by unique_ptr
- * NOTE: This performs a shallow copy of the TensorRelatedT if TensorRelatedT points to a resource
+ * @note: This performs a shallow copy of the TensorRelatedT if TensorRelatedT points to a resource
*/
virtual std::unique_ptr<IPostOp<TensorRelatedT>> clone() const = 0;
virtual ~IPostOp()
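To make the prev_dst_pos convention in the comment above concrete, here is a small self-contained sketch (plain C++ with hypothetical names such as full_argument_list; not the library's API) of how a chained post op's complete argument list can be assembled by inserting the previous op's dst at prev_dst_pos, using the div(div_info, src1, prev_dst_pos = 1) example:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

using Arg = std::string;

// Rebuild the original op's argument list from the post op's extra arguments
// plus the previous op's dst tensor, inserted at prev_dst_pos.
std::vector<Arg> full_argument_list(std::vector<Arg> extra_args, const Arg &prev_dst, std::size_t prev_dst_pos)
{
    extra_args.insert(extra_args.begin() + static_cast<std::ptrdiff_t>(prev_dst_pos), prev_dst);
    return extra_args;
}

int main()
{
    // post op2 from the example: div(div_info, src1, prev_dst_pos = 1)
    // Its only extra tensor argument is src1; the previous op's dst becomes argument 1.
    const auto args = full_argument_list({ "src1" }, "dst_of_post_op1", 1);
    for(const auto &a : args)
    {
        std::cout << a << ' '; // prints: src1 dst_of_post_op1
    }
    std::cout << '\n';
    return 0;
}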
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 1da970e705..748b0f55a1 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -85,16 +85,24 @@ PostOpCLKernelUtils::PostOpCLKernelUtils(const Config &supported_config)
bool PostOpCLKernelUtils::are_post_op_shapes_compliant(const ITensorInfo *dst, const experimental::PostOpList<ITensorInfo *> &post_ops)
{
- // All post ops must be elementwise and must not alter the shape of the original dst tensor after broadcasting
for(const auto &op : post_ops.get_list())
{
for(const auto &tensor : op->arguments())
{
const TensorShape &out_shape = TensorShape::broadcast_shape(dst->tensor_shape(), (*tensor)->tensor_shape());
+ // All post ops must be elementwise and must not alter the shape of the original dst tensor after broadcasting
if(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0))
{
return false;
}
+ // NOTE: Kernel limitation: currently only the following broadcasting types are supported:
+ // 1. Post op arg is scalar, broadcast in both X and Y
+ // 2. Post op arg is of shape: Y=1, X=N, broadcast only in Y
+ // This means the case where the post op arg is of shape Y=M, X=1 (i.e. broadcast only in X) is NOT supported
+ if(dst->dimension(0) > 1 && dst->dimension(1) > 1 && (*tensor)->dimension(0) == 1 && (*tensor)->dimension(1) > 1)
+ {
+ return false;
+ }
}
}
return true;
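The kernel limitation noted in the comment above (only scalar broadcasts and Y=1 row-vector broadcasts of the post op argument are supported) amounts to a small whitelist over the X and Y dimensions. A hedged sketch of that decision, with illustrative names rather than the library's actual helpers, is shown below; the authoritative check remains PostOpCLKernelUtils::are_post_op_shapes_compliant().

#include <cassert>
#include <cstddef>

// Supported broadcast patterns for a post op argument of shape (arg_x, arg_y)
// against a dst of shape (dst_x, dst_y), per the kernel comment:
//   1. scalar argument        -> broadcast in both X and Y
//   2. row vector (Y == 1)    -> broadcast in Y only
//   (a full-shape argument needs no broadcast and is trivially fine)
// A column vector (X == 1, Y == M) would require broadcasting in X only,
// which the kernel does not support, so it falls through and is rejected.
bool is_supported_broadcast(std::size_t dst_x, std::size_t dst_y, std::size_t arg_x, std::size_t arg_y)
{
    const bool scalar_arg   = (arg_x == 1 && arg_y == 1);
    const bool row_vector   = (arg_x == dst_x && arg_y == 1);
    const bool no_broadcast = (arg_x == dst_x && arg_y == dst_y);
    return scalar_arg || row_vector || no_broadcast;
}

int main()
{
    // dst is X=16 (n), Y=22 (m), matching the BroadcastIn* test cases below
    assert(is_supported_broadcast(16, 22, 1, 1));   // scalar: supported
    assert(is_supported_broadcast(16, 22, 16, 1));  // row vector, broadcast in Y: supported
    assert(!is_supported_broadcast(16, 22, 1, 22)); // column vector, broadcast in X only: not supported
    return 0;
}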
diff --git a/src/core/experimental/PostOp.h b/src/core/experimental/PostOp.h
index 64414d2050..7d62bd95e1 100644
--- a/src/core/experimental/PostOp.h
+++ b/src/core/experimental/PostOp.h
@@ -79,9 +79,9 @@ template <typename TensorRelatedT>
struct PostOpEltwiseAdd : public IPostOp<TensorRelatedT>
{
public:
- PostOpEltwiseAdd(TensorRelatedT addend, int prev_op_arg_pos, ConvertPolicy policy)
+ PostOpEltwiseAdd(TensorRelatedT addend, int prev_dst_pos, ConvertPolicy policy)
: _addend{ addend },
- _prev_op_arg_pos{ prev_op_arg_pos },
+ _prev_dst_pos{ prev_dst_pos },
_policy{ policy }
{
}
@@ -93,7 +93,7 @@ public:
PostOpEltwiseAdd &operator=(PostOpEltwiseAdd &&) = default;
int prev_dst_pos() const override
{
- return _prev_op_arg_pos;
+ return _prev_dst_pos;
}
PostOpType type() const override
{
@@ -112,7 +112,7 @@ public:
return std::make_unique<PostOpEltwiseAdd<TensorRelatedT>>(*this);
}
TensorRelatedT _addend;
- int _prev_op_arg_pos;
+ int _prev_dst_pos;
ConvertPolicy _policy;
};
@@ -135,7 +135,7 @@ PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTens
case PostOpType::Eltwise_Add:
{
const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwiseAdd<FromTensorT> *>(post_op.get());
- transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_op_arg_pos, _post_op->_policy);
+ transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy);
break;
}
default:
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
index 4b28e2badc..8ee72d3f03 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
@@ -182,11 +182,11 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
-
// dst tensor auto initialization if not yet initialized
auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
+
auto padding_info = get_padding_info({ src0, src1, src2, dst });
_reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
_use_dummy_work_items = preferred_dummy_work_items_support(CLKernelLibrary::get().get_device());
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index b13c380470..a598780bf6 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -185,11 +185,6 @@ const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { fa
/** Post Ops */
using PostOpArgBroadcast = CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>::PostOpArgBroadcast;
-experimental::PostOpList<PostOpArgBroadcast> empty_post_ops()
-{
- return experimental::PostOpList<PostOpArgBroadcast>{};
-}
-
experimental::PostOpList<PostOpArgBroadcast> post_ops_1()
{
experimental::PostOpList<PostOpArgBroadcast> post_ops{};
@@ -221,20 +216,6 @@ experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
ConvertPolicy::SATURATE);
return post_ops;
}
-experimental::PostOpList<PostOpArgBroadcast> invalid_post_ops_1()
-{
- experimental::PostOpList<PostOpArgBroadcast> post_ops{};
- post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
- std::make_tuple(true, true, false), // If broadcast in dims 0, 1 and 2
- 1,
- ConvertPolicy::SATURATE);
- post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
- std::make_tuple(false, true, false), // If broadcast in dims 0, 1 and 2
- 0,
- ConvertPolicy::SATURATE);
- return post_ops;
-}
-
/** Different Post Op Lists */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_1(),
@@ -242,6 +223,42 @@ const auto post_op_lists = framework::dataset::make("post_op_lists", {
post_ops_3(),
} );
+bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
+{
+ const auto lhs_info = GEMMLHSMatrixInfo(4,4,1,false,true);
+ const auto rhs_info = GEMMRHSMatrixInfo(4,4,1,true,true,false);
+
+ // Create TensorInfo for post op arguments
+ TensorInfo input0_info(TensorShape(k, m, batch), 1, data_type);
+ TensorInfo input1_info(TensorShape(n, k, batch), 1, data_type);
+ TensorInfo input2_info(TensorShape(n), 1, data_type);
+ TensorInfo output_info(TensorShape(n, m, batch), 1, data_type);
+
+ const TensorInfo reshaped_input0_info = input0_info.clone()->set_tensor_shape(misc::shape_calculator::compute_lhs_reshaped_shape(input0_info, lhs_info));
+ const TensorInfo reshaped_input1_info = input1_info.clone()->set_tensor_shape(misc::shape_calculator::compute_rhs_reshaped_shape(input1_info, rhs_info));
+
+ GEMMKernelInfo gemm_info(m, n, k, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
+ false /**< reinterpret the input as 3D */,
+ true /**< Flag used to broadcast the bias addition */,
+ false /**< wider accum */,
+ false /**< has pad y */,
+ ActivationLayerInfo::ActivationFunction::IDENTITY,
+ 1 /**< Multiplication factor for the width of the 1xW transposed block */,
+ 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
+ lhs_info,
+ rhs_info,
+ 0 /**< Offset to be added to each element of the matrix A */,
+ 0 /**< Offset to be added to each element of the matrix B */,
+ post_ops);
+ return bool(ClGemmMatrixMultiplyReshapedKernel::validate(&reshaped_input0_info.clone()->set_is_resizable(true),
+ &reshaped_input1_info.clone()->set_is_resizable(true),
+ &input2_info.clone()->set_is_resizable(true),
+ &output_info.clone()->set_is_resizable(true),1.f,1.f,
+ lhs_info,
+ rhs_info,
+ gemm_info));
+}
+
} // namespace
TEST_SUITE(CL)
@@ -406,116 +423,119 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi
rhs_info,
gemm_info)) == expected, framework::LogLevel::ERRORS);
}
-DATA_TEST_CASE(ValidateFusedPosOps, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(zip(
- framework::dataset::make("Input0Info", { TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32), // OK. Empty post ops
- TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32), // Invalid post op sequences
- TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32), // OK. Supported post ops
-
- }),
- framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
- TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
- TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
-
- })),
- framework::dataset::make("Input2Info", { TensorInfo(TensorShape(21U), 1, DataType::F32),
- TensorInfo(TensorShape(21U), 1, DataType::F32),
- TensorInfo(TensorShape(21U), 1, DataType::F32),
-
- })),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
- TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
- TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
-
- })),
- framework::dataset::make("LHSMInfo",{
- GEMMLHSMatrixInfo(4,4,1,false,true),
- GEMMLHSMatrixInfo(4,4,1,false,true),
- GEMMLHSMatrixInfo(4,4,1,false,true),
-
- })),
- framework::dataset::make("RHSMInfo",{
- GEMMRHSMatrixInfo(4,4,1,true,true,false),
- GEMMRHSMatrixInfo(4,4,1,true,true,false),
- GEMMRHSMatrixInfo(4,4,1,true,true,false),
-
-
- })),
-
-
- framework::dataset::make("GEMMInfo",{
- GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
- 21 /**<N Number of RHS columns*/,
- 13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
- false /**< reinterpret the input as 3D */,
- true /**< Flag used to broadcast the bias addition */,
- false /**< wider accumm */,
- false /**< has pad y */,
- ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
- 1 /**< Multiplication factor for the width of the 1xW transposed block */,
- 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
- GEMMLHSMatrixInfo(4,4,1,false,true),
- GEMMRHSMatrixInfo(4,4,1,true,true,false),
- 0 /**< Offset to be added to each element of the matrix A */,
- 0 /**< Offset to be added to each element of the matrix B */),
+TEST_SUITE(ValidateFusedPostOpsConfigs)
+TEST_SUITE(Invalid)
+TEST_CASE(UnsupportedPostOpSequence, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const unsigned int m = 17;
+ const unsigned int n = 1;
+ const unsigned int k = 13;
+ const unsigned int batch = 2;
+ TensorShape post_op_arg0_shape(n, m, batch);
+ TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+ auto post_op_arg1_info = post_op_arg_info.clone();
+
+ // Unsupported sequence of post ops
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
+ &post_op_arg_info,
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
+ post_op_arg1_info.get(),
+ 0,
+ ConvertPolicy::SATURATE);
- GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
- 21 /**<N Number of RHS columns*/,
- 13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
- false /**< reinterpret the input as 3D */,
- true /**< Flag used to broadcast the bias addition */,
- false /**< wider accumm */,
- false /**< has pad y */,
- ActivationLayerInfo::ActivationFunction::IDENTITY,
- 1 /**< Multiplication factor for the width of the 1xW transposed block */,
- 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
- GEMMLHSMatrixInfo(4,4,1,false,true),
- GEMMRHSMatrixInfo(4,4,1,true,true,false),
- 0 /**< Offset to be added to each element of the matrix A */,
- 0 /**< Offset to be added to each element of the matrix B */),
- GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
- 21 /**<N Number of RHS columns*/,
- 13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
- false /**< reinterpret the input as 3D */,
- true /**< Flag used to broadcast the bias addition */,
- false /**< wider accumm */,
- false /**< has pad y */,
- ActivationLayerInfo::ActivationFunction::IDENTITY,
- 1 /**< Multiplication factor for the width of the 1xW transposed block */,
- 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
- GEMMLHSMatrixInfo(4,4,1,false,true),
- GEMMRHSMatrixInfo(4,4,1,true,true,false),
- 0 /**< Offset to be added to each element of the matrix A */,
- 0 /**< Offset to be added to each element of the matrix B */),
- })),
- framework::dataset::make("PostOps",{
- empty_post_ops(),
- invalid_post_ops_1(),
- post_ops_1(),
- })),
- framework::dataset::make("Expected", { true, false, true})),
- input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, post_ops, expected)
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OutputWidened, framework::DatasetMode::ALL)
{
- // Create TensorInfo for post op arguments
- std::vector<TensorInfo> post_op_tensor_infos;
- auto populated_post_ops = experimental::transform_post_op_list_arguments<PostOpArgBroadcast, ITensorInfo*>(post_ops,
- [&output_info, &post_op_tensor_infos](auto broadcast){
- post_op_tensor_infos.emplace_back(TensorShape{
- std::get<0>(broadcast) ? 1 : output_info.dimension(0),
- std::get<1>(broadcast) ? 1 : output_info.dimension(1),
- std::get<2>(broadcast) ? 1 : output_info.dimension(2)
- }, 1, output_info.data_type());
- return &post_op_tensor_infos.back();
- });
- GEMMKernelInfo gemm_info_with_post_ops(std::move(gemm_info));
- gemm_info_with_post_ops.post_ops = populated_post_ops;
- ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
- &input1_info.clone()->set_is_resizable(true),
- &input2_info.clone()->set_is_resizable(true),
- &output_info.clone()->set_is_resizable(true),1.f,1.f,
- lhs_info,
- rhs_info,
- gemm_info_with_post_ops)) == expected, framework::LogLevel::ERRORS);
+ // Invalid broadcast: post op tensors "widen" the output tensor
+ const auto data_type = DataType::F32;
+ const unsigned int m = 17;
+ const unsigned int n = 1;
+ const unsigned int k = 13;
+ const unsigned int batch = 2;
+ TensorShape post_op_arg_shape(n + 4, m, batch); // output's X dimension (n) is "widened", which is not allowed
+ TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
+{
+ // Invalid broadcast: post op tensors broadcast in the first dimension (X) only
+ const auto data_type = DataType::F32;
+ const unsigned int m = 22;
+ const unsigned int n = 16;
+ const unsigned int k = 15;
+ const unsigned int batch = 3;
+ TensorShape post_op_arg_shape(1, m, batch);
+ TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // Invalid
+TEST_SUITE(Valid)
+TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const unsigned int m = 22;
+ const unsigned int n = 16;
+ const unsigned int k = 15;
+ const unsigned int batch = 3;
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInYDimOnly, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const unsigned int m = 22;
+ const unsigned int n = 16;
+ const unsigned int k = 15;
+ const unsigned int batch = 3;
+ TensorShape post_op_arg_shape(n, 1, batch);
+ TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInBothXandYDims, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const unsigned int m = 22;
+ const unsigned int n = 16;
+ const unsigned int k = 15;
+ const unsigned int batch = 3;
+ TensorShape post_op_arg_shape(1, 1, batch);
+ TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(BroadcastInAllDims, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const unsigned int m = 22;
+ const unsigned int n = 16;
+ const unsigned int k = 15;
+ const unsigned int batch = 3;
+ TensorShape post_op_arg_shape(1, 1, 1);
+ TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
}
+TEST_SUITE_END() // Valid
+TEST_SUITE_END() // ValidateFusedPostOpsConfigs
TEST_SUITE(Float)
TEST_SUITE(FP32)
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 64694f0e7c..30ba667b95 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -193,7 +193,7 @@ inline ::std::ostream &operator<<(::std::ostream &os, const experimental::IPostO
{
os << "<";
os << post_op.type() << ",";
- os << "prev_op_arg_pos=" << post_op.prev_dst_pos() << ",";
+ os << "prev_dst_pos=" << post_op.prev_dst_pos() << ",";
switch(post_op.type())
{
case experimental::PostOpType::Activation: