From 88d5b22eb5574d8b564474df2c758d222b3b5547 Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Fri, 6 Apr 2018 12:24:55 +0100 Subject: COMPMID-1035 - Add ResneXt50 as a graph example Change-Id: I42f0e7dab38e45b5eecfe6858eaecee8939c8585 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129291 Reviewed-by: Georgios Pinitas Reviewed-by: Anthony Barbier Tested-by: Jenkins --- utils/GraphUtils.cpp | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) (limited to 'utils/GraphUtils.cpp') diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index 145e44950b..0edb6f2a56 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -129,6 +129,45 @@ bool DummyAccessor::access_tensor(ITensor &tensor) return ret; } +NumPyAccessor::NumPyAccessor(std::string npy_path, TensorShape shape, DataType data_type, std::ostream &output_stream) + : _npy_tensor(), _filename(std::move(npy_path)), _output_stream(output_stream) +{ + NumPyBinLoader loader(_filename); + + TensorInfo info(shape, 1, data_type); + _npy_tensor.allocator()->init(info); + _npy_tensor.allocator()->allocate(); + + loader.access_tensor(_npy_tensor); +} + +template +void NumPyAccessor::access_numpy_tensor(ITensor &tensor) +{ + const int num_elements = tensor.info()->total_size(); + int num_mismatches = utils::compare_tensor(tensor, _npy_tensor); + float percentage_mismatches = static_cast(num_mismatches) / num_elements; + + _output_stream << "Results: " << 100.f - (percentage_mismatches * 100) << " % matches with the provided output[" << _filename << "]." << std::endl; +} + +bool NumPyAccessor::access_tensor(ITensor &tensor) +{ + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32); + ARM_COMPUTE_ERROR_ON(_npy_tensor.info()->dimension(0) != tensor.info()->dimension(0)); + + switch(tensor.info()->data_type()) + { + case DataType::F32: + access_numpy_tensor(tensor); + break; + default: + ARM_COMPUTE_ERROR("NOT SUPPORTED!"); + } + + return false; +} + PPMAccessor::PPMAccessor(std::string ppm_path, bool bgr, std::unique_ptr preprocessor) : _ppm_path(std::move(ppm_path)), _bgr(bgr), _preprocessor(std::move(preprocessor)) { -- cgit v1.2.1