diff options
Diffstat (limited to 'utils/Utils.h')
-rw-r--r-- | utils/Utils.h | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/utils/Utils.h b/utils/Utils.h index cadba3a088..6cb71fd3ba 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -924,6 +924,43 @@ void init_sgemm_output(T &dst, T &src0, T &src1, arm_compute::DataType dt) * @return The free memory in kB */ uint64_t get_mem_free_from_meminfo(); + +/** Compare to tensor + * + * @param[in] tensor1 First tensor to be compared. + * @param[in] tensor2 Second tensor to be compared. + * + * @return The number of mismatches + */ +template <typename T> +int compare_tensor(ITensor &tensor1, ITensor &tensor2) +{ + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&tensor1, &tensor2); + ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(&tensor1, &tensor2); + + int num_mismatches = 0; + Window window; + window.use_tensor_dimensions(tensor1.info()->tensor_shape()); + + map(tensor1, true); + map(tensor2, true); + Iterator itensor1(&tensor1, window); + Iterator itensor2(&tensor2, window); + + execute_window_loop(window, [&](const Coordinates & id) + { + if(std::abs(*reinterpret_cast<T *>(itensor1.ptr()) - *reinterpret_cast<T *>(itensor2.ptr())) > 0.00001) + { + ++num_mismatches; + } + }, + itensor1, itensor2); + + unmap(itensor1); + unmap(itensor2); + + return num_mismatches; +} } // namespace utils } // namespace arm_compute #endif /* __UTILS_UTILS_H__*/ |