diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/TypePrinter.h | 82 |
1 files changed, 72 insertions, 10 deletions
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index 23e73f6a9e..f47943aa77 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -473,14 +473,14 @@ inline ::std::ostream &operator<<(::std::ostream &os, const BoundingBoxTransform } #if defined(ARM_COMPUTE_ENABLE_BF16) -inline ::std::ostream &operator<<(::std::ostream &os, const bfloat16& v) +inline ::std::ostream &operator<<(::std::ostream &os, const bfloat16 &v) { std::stringstream str; str << v; os << str.str(); return os; } -#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ /** Formatted output of the BoundingBoxTransformInfo type. * @@ -3252,19 +3252,81 @@ inline std::string to_string(const Conv3dInfo &conv3d_info) return str.str(); } -inline ::std::ostream &operator<<(::std::ostream &os, const arm_gemm::WeightFormat &wf) +/** Formatted output of the arm_compute::WeightFormat type. + * + * @param[in] wf arm_compute::WeightFormat Type to output. + * + * @return Formatted string. + */ +inline std::string to_string(const WeightFormat wf) { - os << arm_gemm::to_string(wf); - return os; +#define __CASE_WEIGHT_FORMAT(wf) \ +case WeightFormat::wf: \ + return #wf; + switch(wf) + { + __CASE_WEIGHT_FORMAT(UNSPECIFIED) + __CASE_WEIGHT_FORMAT(ANY) + __CASE_WEIGHT_FORMAT(OHWI) + __CASE_WEIGHT_FORMAT(OHWIo2) + __CASE_WEIGHT_FORMAT(OHWIo4) + __CASE_WEIGHT_FORMAT(OHWIo8) + __CASE_WEIGHT_FORMAT(OHWIo16) + __CASE_WEIGHT_FORMAT(OHWIo32) + __CASE_WEIGHT_FORMAT(OHWIo64) + __CASE_WEIGHT_FORMAT(OHWIo128) + __CASE_WEIGHT_FORMAT(OHWIo4i2) + __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i2) + __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i2) + __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i2) + __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i2) + __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo4i4) + __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i4) + __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i4) + __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i4) + __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i4) + __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo2i8) + __CASE_WEIGHT_FORMAT(OHWIo4i8) + __CASE_WEIGHT_FORMAT(OHWIo8i8) + __CASE_WEIGHT_FORMAT(OHWIo16i8) + __CASE_WEIGHT_FORMAT(OHWIo32i8) + __CASE_WEIGHT_FORMAT(OHWIo64i8) + default: + return "invalid value"; + } +#undef __CASE_WEIGHT_FORMAT } -inline std::string to_string(const arm_gemm::WeightFormat wf) + +/** Formatted output of the arm_compute::WeightFormat type. + * + * @param[out] os Output stream. + * @param[in] wf WeightFormat to output. + * + * @return Modified output stream. + */ +inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::WeightFormat &wf) { - std::stringstream str; - str << wf; - return str.str(); + os << to_string(wf); + return os; } -inline std::string to_string(const std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat> values) +/** Formatted output of the std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat> tuple. + * + * @param[in] values tuple of input and output tensor shapes and WeightFormat used. + * + * @return Formatted string. + */ +inline std::string to_string(const std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat> values) { std::stringstream str; str << "[Input shape = " << std::get<0>(values); |