diff options
Diffstat (limited to 'utils/GraphUtils.cpp')
-rw-r--r-- | utils/GraphUtils.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index b714c55136..26ea02a9ff 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -153,24 +153,29 @@ NumPyAccessor::NumPyAccessor(std::string npy_path, TensorShape shape, DataType d } template <typename T> -void NumPyAccessor::access_numpy_tensor(ITensor &tensor) +void NumPyAccessor::access_numpy_tensor(ITensor &tensor, T tolerance) { const int num_elements = tensor.info()->tensor_shape().total_size(); - int num_mismatches = utils::compare_tensor<T>(tensor, _npy_tensor); + int num_mismatches = utils::compare_tensor<T>(tensor, _npy_tensor, tolerance); float percentage_mismatches = static_cast<float>(num_mismatches) / num_elements; _output_stream << "Results: " << 100.f - (percentage_mismatches * 100) << " % matches with the provided output[" << _filename << "]." << std::endl; + _output_stream << " " << num_elements - num_mismatches << " out of " << num_elements << " matches with the provided output[" << _filename << "]." << std::endl + << std::endl; } bool NumPyAccessor::access_tensor(ITensor &tensor) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32, DataType::QASYMM8); ARM_COMPUTE_ERROR_ON(_npy_tensor.info()->dimension(0) != tensor.info()->dimension(0)); switch(tensor.info()->data_type()) { + case DataType::QASYMM8: + access_numpy_tensor<qasymm8_t>(tensor, 0); + break; case DataType::F32: - access_numpy_tensor<float>(tensor); + access_numpy_tensor<float>(tensor, 0.0001f); break; default: ARM_COMPUTE_ERROR("NOT SUPPORTED!"); |