aboutsummaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-06-27 12:34:20 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:10 +0000
commit7908de7ea914c1c968af01880357791a300483be (patch)
tree591445f3e351c4f8176c21cd11a7f31b8f574151 /utils
parent9e454f36db4e13a3794289cb13e2cefbfbde1047 (diff)
downloadComputeLibrary-7908de7ea914c1c968af01880357791a300483be.tar.gz
COMPMID-1309: Add result validation on full network validation
Change-Id: I5a7e2b198593c782eb812dd3f013ee2b91dc895f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/137627 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'utils')
-rw-r--r--utils/GraphUtils.cpp115
-rw-r--r--utils/GraphUtils.h45
2 files changed, 159 insertions, 1 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp
index 4db053cf9f..d94dcb0d86 100644
--- a/utils/GraphUtils.cpp
+++ b/utils/GraphUtils.cpp
@@ -268,6 +268,121 @@ bool ValidationInputAccessor::access_tensor(arm_compute::ITensor &tensor)
return ret;
}
+ValidationOutputAccessor::ValidationOutputAccessor(const std::string &image_list,
+ size_t top_n,
+ std::ostream &output_stream,
+ unsigned int start,
+ unsigned int end)
+ : _results(), _output_stream(output_stream), _top_n(top_n), _offset(0), _positive_samples(0)
+{
+ ARM_COMPUTE_ERROR_ON_MSG(start > end, "Invalid validation range!");
+ ARM_COMPUTE_ERROR_ON(top_n == 0);
+
+ std::ifstream ifs;
+ try
+ {
+ ifs.exceptions(std::ifstream::badbit);
+ ifs.open(image_list, std::ios::in | std::ios::binary);
+
+ // Parse image correctly classified labels
+ unsigned int counter = 0;
+ for(std::string line; !std::getline(ifs, line).fail() && counter <= end; ++counter)
+ {
+ // Add label if within range
+ if(counter >= start)
+ {
+ std::stringstream linestream(line);
+ std::string image_name;
+ int result;
+
+ linestream >> image_name >> result;
+ _results.emplace_back(result);
+ }
+ }
+ }
+ catch(const std::ifstream::failure &e)
+ {
+ ARM_COMPUTE_ERROR("Accessing %s: %s", image_list.c_str(), e.what());
+ }
+}
+
+void ValidationOutputAccessor::reset()
+{
+ _offset = 0;
+ _positive_samples = 0;
+}
+
+bool ValidationOutputAccessor::access_tensor(arm_compute::ITensor &tensor)
+{
+ if(_offset < _results.size())
+ {
+ // Get results
+ std::vector<size_t> tensor_results;
+ switch(tensor.info()->data_type())
+ {
+ case DataType::QASYMM8:
+ tensor_results = access_predictions_tensor<uint8_t>(tensor);
+ break;
+ case DataType::F32:
+ tensor_results = access_predictions_tensor<float>(tensor);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("NOT SUPPORTED!");
+ }
+
+ // Check if tensor results are within top-n accuracy
+ size_t correct_label = _results[_offset++];
+ auto is_valid_label = [&](size_t label)
+ {
+ return label == correct_label;
+ };
+
+ if(std::any_of(std::begin(tensor_results), std::begin(tensor_results) + _top_n - 1, is_valid_label))
+ {
+ ++_positive_samples;
+ }
+ }
+
+ // Report top_n accuracy
+ bool ret = _offset >= _results.size();
+ if(ret)
+ {
+ size_t total_samples = _results.size();
+ size_t negative_samples = total_samples - _positive_samples;
+ float accuracy = _positive_samples / static_cast<float>(total_samples);
+
+ _output_stream << "----------Top " << _top_n << " accuracy ----------" << std::endl
+ << std::endl;
+ _output_stream << "Positive samples : " << _positive_samples << std::endl;
+ _output_stream << "Negative samples : " << negative_samples << std::endl;
+ _output_stream << "Accuracy : " << accuracy << std::endl;
+ }
+
+ return ret;
+}
+
+template <typename T>
+std::vector<size_t> ValidationOutputAccessor::access_predictions_tensor(arm_compute::ITensor &tensor)
+{
+ // Get the predicted class
+ std::vector<size_t> index;
+
+ const auto output_net = reinterpret_cast<T *>(tensor.buffer() + tensor.info()->offset_first_element_in_bytes());
+ const size_t num_classes = tensor.info()->dimension(0);
+
+ index.resize(num_classes);
+
+ // Sort results
+ std::iota(std::begin(index), std::end(index), static_cast<size_t>(0));
+ std::sort(std::begin(index), std::end(index),
+ [&](size_t a, size_t b)
+ {
+ return output_net[a] > output_net[b];
+ });
+
+ return index;
+}
+
TopNPredictionsAccessor::TopNPredictionsAccessor(const std::string &labels_path, size_t top_n, std::ostream &output_stream)
: _labels(), _output_stream(output_stream), _top_n(top_n)
{
diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h
index 349d8558fc..768c608d26 100644
--- a/utils/GraphUtils.h
+++ b/utils/GraphUtils.h
@@ -185,7 +185,7 @@ public:
* @param[in] start (Optional) Start range
* @param[in] end (Optional) End range
*
- * @note
+ * @note Range is defined as [start, end]
*/
ValidationInputAccessor(const std::string &image_list,
std::string images_path,
@@ -203,6 +203,49 @@ private:
size_t _offset;
};
+/** Output Accessor used for network validation */
+class ValidationOutputAccessor final : public graph::ITensorAccessor
+{
+public:
+ /** Default Constructor
+ *
+ * @param[in] image_list File containing all the images and labels results
+ * @param[in] top_n (Optional) Top N accuracy (Defaults to 5)
+ * @param[out] output_stream (Optional) Output stream (Defaults to the standard output stream)
+ * @param[in] start (Optional) Start range
+ * @param[in] end (Optional) End range
+ *
+ * @note Range is defined as [start, end]
+ */
+ ValidationOutputAccessor(const std::string &image_list,
+ size_t top_n = 5,
+ std::ostream &output_stream = std::cout,
+ unsigned int start = 0,
+ unsigned int end = 0);
+ /** Reset accessor state */
+ void reset();
+
+ // Inherited methods overriden:
+ bool access_tensor(ITensor &tensor) override;
+
+private:
+ /** Access predictions of the tensor
+ *
+ * @tparam T Tensor elements type
+ *
+ * @param[in] tensor Tensor to read the predictions from
+ */
+ template <typename T>
+ std::vector<size_t> access_predictions_tensor(ITensor &tensor);
+
+private:
+ std::vector<int> _results;
+ std::ostream &_output_stream;
+ size_t _top_n;
+ size_t _offset;
+ size_t _positive_samples;
+};
+
/** Result accessor class */
class TopNPredictionsAccessor final : public graph::ITensorAccessor
{