[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest_3.hxx
1/************************************************************************/
2/* */
3/* Copyright 2014-2015 by Ullrich Koethe and Philip Schill */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35#ifndef VIGRA_RF3_HXX
36#define VIGRA_RF3_HXX
37
38#include <vector>
39#include <set>
40#include <map>
41#include <stack>
42#include <algorithm>
43
44#include "multi_array.hxx"
45#include "sampling.hxx"
46#include "threading.hxx"
47#include "threadpool.hxx"
48#include "random_forest_3/random_forest.hxx"
49#include "random_forest_3/random_forest_common.hxx"
50#include "random_forest_3/random_forest_visitors.hxx"
51
52namespace vigra
53{
54
55/** \addtogroup MachineLearning
56**/
57//@{
58
59/** \brief Random forest version 3.
60
61 This namespace contains VIGRA's 3rd version of the random forest classification/regression algorithm.
62 This version is much easier to customize than previous versions because it consistently separates
63 algorithms from the forest representation, following the design of the LEMON graph library.
64*/
65namespace rf3
66{
67
68template <typename FEATURES, typename LABELS>
69struct DefaultRF
70{
71 typedef RandomForest<FEATURES,
72 LABELS,
75};
76
77namespace detail
78{
79
80// In random forest training, you can store different items in the leaves,
81// depending on the accumulator. Typically, you want to store the class
82// distributions, but the ArgMaxAcc does not need this. The RFMapUpdater is
83// used to store only the necessary information.
84
85template <typename ACC>
86struct RFMapUpdater
87{
88 template <typename A, typename B>
89 void operator()(A & a, B const & b) const
90 {
91 a = b;
92 }
93};
94
95
96
97template <>
98struct RFMapUpdater<ArgMaxAcc>
99{
100 template <typename A, typename B>
101 void operator()(A & a, B const & b) const
102 {
103 auto it = std::max_element(b.begin(), b.end());
104 a = std::distance(b.begin(), it);
105 }
106};
107
108
109
110/// Loop over the split dimensions and compute the score of all considered splits.
111template <typename FEATURES, typename LABELS, typename SAMPLER, typename SCORER>
112void split_score(
113 FEATURES const & features,
114 LABELS const & labels,
115 std::vector<double> const & instance_weights,
116 std::vector<size_t> const & instances,
117 SAMPLER const & dim_sampler,
118 SCORER & score
119){
120 typedef typename FEATURES::value_type FeatureType;
121
122 auto feats = std::vector<FeatureType>(instances.size()); // storage for the features
123 auto sorted_indices = std::vector<size_t>(feats.size()); // storage for the index sort result
124 auto tosort_instances = std::vector<size_t>(feats.size()); // storage for the sorted instances
125
126 for (int i = 0; i < dim_sampler.sampleSize(); ++i)
127 {
128 size_t const d = dim_sampler[i];
129
130 // Copy the features to a vector with the correct size (so the sort is faster because of data locality).
131 for (size_t kk = 0; kk < instances.size(); ++kk)
132 feats[kk] = features(instances[kk], d);
133
134 // Sort the features.
138
139 // Get the score of the splits.
141 }
142}
143
144
145
146/**
147 * @brief Train a single randomized decision tree.
148 */
149template <typename RF, typename SCORER, typename VISITOR, typename STOP, typename RANDENGINE>
150void random_forest_single_tree(
151 typename RF::Features const & features,
152 MultiArray<1, size_t> const & labels,
153 RandomForestOptions const & options,
155 STOP stop,
156 RF & tree,
157 RANDENGINE const & randengine
158){
159 typedef typename RF::Features Features;
160 typedef typename Features::value_type FeatureType;
161 typedef LessEqualSplitTest<FeatureType> SplitTests;
162 typedef typename RF::Node Node;
163 typedef typename RF::ACC ACC;
164 typedef typename ACC::input_type ACCInputType;
165
166 static_assert(std::is_same<SplitTests, typename RF::SplitTests>::value,
167 "random_forest_single_tree(): Wrong Random Forest class.");
168
169 // the api is seriously broke...
170 int const num_instances = features.shape()[0];
171 size_t const num_features = features.shape()[1];
172 auto const & spec = tree.problem_spec_;
173
174 vigra_precondition(num_instances == labels.size(),
175 "random_forest_single_tree(): Shape mismatch between features and labels.");
176 vigra_precondition(num_features == spec.num_features_,
177 "random_forest_single_tree(): Wrong number of features.");
178
179 // Create the index vector for bookkeeping.
180 std::vector<size_t> instance_indices(num_instances);
181 std::iota(instance_indices.begin(), instance_indices.end(), 0);
182 typedef std::vector<size_t>::iterator InstanceIter;
183
184 // Create the weights for the bootstrap sample.
185 std::vector<double> instance_weights(num_instances, 1.0);
186 if (options.bootstrap_sampling_)
187 {
188 std::fill(instance_weights.begin(), instance_weights.end(), 0.0);
189 Sampler<MersenneTwister> sampler(num_instances,
190 SamplerOptions().withReplacement().stratified(options.use_stratification_),
191 &randengine);
192 sampler.sample();
193 for (int i = 0; i < sampler.sampleSize(); ++i)
194 {
195 int const index = sampler[i];
196 ++instance_weights[index];
197 }
198 }
199
200 // Multiply the instance weights by the class weights.
201 if (options.class_weights_.size() > 0)
202 {
203 for (size_t i = 0; i < instance_weights.size(); ++i)
204 instance_weights[i] *= options.class_weights_.at(labels(i));
205 }
206
207 // Create the sampler for the split dimensions.
208 auto const mtry = spec.actual_mtry_;
209 Sampler<MersenneTwister> dim_sampler(num_features, SamplerOptions().withoutReplacement().sampleSize(mtry), &randengine);
210
211 // Create the node stack and place the root node inside.
212 std::stack<Node> node_stack;
213 typedef std::pair<InstanceIter, InstanceIter> IterPair;
214 PropertyMap<Node, IterPair> instance_range; // begin and end of the instances of a node in the bookkeeping vector
215 PropertyMap<Node, std::vector<double> > node_distributions; // the class distributions in the nodes
216 PropertyMap<Node, size_t> node_depths; // the depth of each node
217 {
218 auto const rootnode = tree.graph_.addNode();
219 node_stack.push(rootnode);
220
222
223 std::vector<double> priors(spec.num_classes_, 0.0);
224 for (auto i : instance_indices)
225 priors[labels(i)] += instance_weights[i];
227
228 node_depths.insert(rootnode, 0);
229 }
230
231 // Call the visitor.
232 visitor.visit_before_tree(tree, features, labels, instance_weights);
233
234 // Split the nodes.
235 detail::RFMapUpdater<ACC> node_map_updater;
236 while (!node_stack.empty())
237 {
238 // Get the data of the current node.
239 auto const node = node_stack.top();
240 node_stack.pop();
241 auto const begin = instance_range.at(node).first;
242 auto const end = instance_range.at(node).second;
243 auto const & priors = node_distributions.at(node);
244 auto const depth = node_depths.at(node);
245
246 // Get the instances with weight > 0.
247 std::vector<size_t> used_instances;
248 for (auto it = begin; it != end; ++it)
249 if (instance_weights[*it] > 1e-10)
250 used_instances.push_back(*it);
251
252 // Find the best split.
253 dim_sampler.sample();
255 if (options.resample_count_ == 0 || used_instances.size() <= options.resample_count_)
256 {
257 // Find the split using all instances.
258 detail::split_score(
259 features,
260 labels,
264 score
265 );
266 }
267 else
268 {
269 // Generate a random subset of the instances.
270 Sampler<MersenneTwister> resampler(used_instances.begin(), used_instances.end(), SamplerOptions().withoutReplacement().sampleSize(options.resample_count_), &randengine);
271 resampler.sample();
272 auto indices = std::vector<size_t>(options.resample_count_);
273 for (size_t i = 0; i < options.resample_count_; ++i)
274 indices[i] = used_instances[resampler[i]];
275
276 // Find the split using the subset.
277 detail::split_score(
278 features,
279 labels,
281 indices,
283 score
284 );
285 }
286
287 // If no split was found, the node is terminal.
288 if (!score.split_found_)
289 {
290 tree.node_responses_.insert(node, ACCInputType());
291 node_map_updater(tree.node_responses_.at(node), node_distributions.at(node));
292 continue;
293 }
294
295 // Create the child nodes and split the instances accordingly.
296 auto const n_left = tree.graph_.addNode();
297 auto const n_right = tree.graph_.addNode();
298 tree.graph_.addArc(node, n_left);
299 tree.graph_.addArc(node, n_right);
300 auto const best_split = score.best_split_;
301 auto const best_dim = score.best_dim_;
302 auto const split_iter = std::partition(begin, end,
303 [&](size_t i)
304 {
305 return features(i, best_dim) <= best_split;
306 }
307 );
308
309 // Call the visitor.
310 visitor.visit_after_split(tree, features, labels, instance_weights, score, begin, split_iter, end);
311
312 instance_range.insert(n_left, IterPair(begin, split_iter));
314 tree.split_tests_.insert(node, SplitTests(best_dim, best_split));
315 node_depths.insert(n_left, depth+1);
316 node_depths.insert(n_right, depth+1);
317
318 // Compute the class distribution for the left child.
319 auto priors_left = std::vector<double>(spec.num_classes_, 0.0);
320 for (auto it = begin; it != split_iter; ++it)
321 priors_left[labels(*it)] += instance_weights[*it];
323
324 // Check if the left child is terminal.
325 if (stop(labels, RFNodeDescription<decltype(priors_left)>(depth+1, priors_left)))
326 {
327 tree.node_responses_.insert(n_left, ACCInputType());
328 node_map_updater(tree.node_responses_.at(n_left), node_distributions.at(n_left));
329 }
330 else
331 {
332 node_stack.push(n_left);
333 }
334
335 // Compute the class distribution for the right child.
336 auto priors_right = std::vector<double>(spec.num_classes_, 0.0);
337 for (auto it = split_iter; it != end; ++it)
338 priors_right[labels(*it)] += instance_weights[*it];
340
341 // Check if the right child is terminal.
342 if (stop(labels, RFNodeDescription<decltype(priors_right)>(depth+1, priors_right)))
343 {
344 tree.node_responses_.insert(n_right, ACCInputType());
345 node_map_updater(tree.node_responses_.at(n_right), node_distributions.at(n_right));
346 }
347 else
348 {
349 node_stack.push(n_right);
350 }
351 }
352
353 // Call the visitor.
354 visitor.visit_after_tree(tree, features, labels, instance_weights);
355}
356
357
358
359/// \brief Preprocess the labels and call the train functions on the single trees.
360template <typename FEATURES,
361 typename LABELS,
362 typename VISITOR,
363 typename SCORER,
364 typename STOP,
365 typename RANDENGINE>
367random_forest_impl(
368 FEATURES const & features,
369 LABELS const & labels,
370 RandomForestOptions const & options,
372 STOP const & stop,
374){
375 // typedef FEATURES Features;
376 typedef LABELS Labels;
377 // typedef typename Features::value_type FeatureType;
378 typedef typename Labels::value_type LabelType;
380
382 pspec.num_instances(features.shape()[0])
383 .num_features(features.shape()[1])
384 .actual_mtry(options.get_features_per_node(features.shape()[1]))
385 .actual_msample(labels.size());
386
387 // Check the number of trees.
388 size_t const tree_count = options.tree_count_;
389 vigra_precondition(tree_count > 0, "random_forest_impl(): tree_count must not be zero.");
390 std::vector<RF> trees(tree_count);
391
392 // Transform the labels to 0, 1, 2, ...
393 std::set<LabelType> const dlabels(labels.begin(), labels.end());
394 std::vector<LabelType> const distinct_labels(dlabels.begin(), dlabels.end());
395 pspec.distinct_classes(distinct_labels);
396 std::map<LabelType, size_t> label_map;
397 for (size_t i = 0; i < distinct_labels.size(); ++i)
398 {
400 }
401
403 for (size_t i = 0; i < (size_t)labels.size(); ++i)
404 {
405 transformed_labels(i) = label_map[labels(i)];
406 }
407
408 // Check the vector with the class weights.
409 vigra_precondition(options.class_weights_.size() == 0 || options.class_weights_.size() == distinct_labels.size(),
410 "random_forest_impl(): The number of class weights must be 0 or equal to the number of classes.");
411
412 // Write the problem specification into the trees.
413 for (auto & t : trees)
414 t.problem_spec_ = pspec;
415
416 // Find the correct number of threads.
417 size_t n_threads = 1;
418 if (options.n_threads_ >= 1)
419 n_threads = options.n_threads_;
420 else if (options.n_threads_ == -1)
421 n_threads = std::thread::hardware_concurrency();
422
423 // Use the global random engine to create seeds for the random engines that run in the threads.
425 std::set<UInt32> seeds;
426 while (seeds.size() < n_threads)
427 {
428 seeds.insert(rand_functor());
429 }
430 vigra_assert(seeds.size() == n_threads, "random_forest_impl(): Could not create random seeds.");
431
432 // Create the random engines that run in the threads.
433 std::vector<RANDENGINE> rand_engines;
434 for (auto seed : seeds)
435 {
436 rand_engines.push_back(RANDENGINE(seed));
437 }
438
439 // Call the visitor.
440 visitor.visit_before_training();
441
442 // Copy the visitor for each tree.
443 // We must change the type, since the original visitor chain holds references and therefore a default copy would be useless.
444 typedef typename VisitorCopy<VISITOR>::type VisitorCopyType;
445 std::vector<VisitorCopyType> tree_visitors;
446 for (size_t i = 0; i < tree_count; ++i)
447 {
448 tree_visitors.emplace_back(visitor);
449 }
450
451 // Train the trees.
452 ThreadPool pool((size_t)n_threads);
453 std::vector<threading::future<void> > futures;
454 for (size_t i = 0; i < tree_count; ++i)
455 {
456 futures.emplace_back(
457 pool.enqueue([&features, &transformed_labels, &options, &tree_visitors, &stop, &trees, i, &rand_engines](size_t thread_id)
458 {
459 random_forest_single_tree<RF, SCORER, VisitorCopyType, STOP>(features, transformed_labels, options, tree_visitors[i], stop, trees[i], rand_engines[thread_id]);
460 }
461 )
462 );
463 }
464 for (auto & fut : futures)
465 fut.get();
466
467 // Merge the trees together.
468 RF rf(trees[0]);
469 rf.options_ = options;
470 for (size_t i = 1; i < trees.size(); ++i)
471 {
472 rf.merge(trees[i]);
473 }
474
475 // Call the visitor.
476 visitor.visit_after_training(tree_visitors, rf, features, labels);
477
478 return rf;
479}
480
481
482
483/// \brief Get the stop criterion from the option object and pass it as template argument.
484template <typename FEATURES, typename LABELS, typename VISITOR, typename SCORER, typename RANDENGINE>
485inline
487random_forest_impl0(
488 FEATURES const & features,
489 LABELS const & labels,
490 RandomForestOptions const & options,
493){
494 if (options.max_depth_ > 0)
496 else if (options.min_num_instances_ > 1)
498 else if (options.node_complexity_tau_ > 0)
500 else
502}
503
504} // namespace detail
505
506/********************************************************/
507/* */
508/* random_forest */
509/* */
510/********************************************************/
511
512/** \brief Train a \ref vigra::rf3::RandomForest classifier.
513
514 This factory function constructs a \ref vigra::rf3::RandomForest classifier and trains
515 it for the given features and labels. They must be given as a matrix with shape
516 <tt>num_instances x num_features</tt> and an array with length <tt>num_instances</tt> respectively.
517 Most training options (such as number of trees in the forest, termination and split criteria,
518 and number of threads for parallel training) are specified via an option object of type \ref vigra::rf3::RandomForestOptions. Optional visitors are typically used to compute the
519 out-of-bag error of the classifier (use \ref vigra::rf3::OOBError) and estimate variable importance
520 on the basis of the Gini gain (use \ref vigra::rf3::VariableImportance). You can also provide
521 a specific random number generator instance, which is especially useful when you want to
522 enforce deterministic algorithm behavior during debugging.
523
524 <b> Declaration:</b>
525
526 \code
527 namespace vigra { namespace rf3 {
528 template <typename FEATURES,
529 typename LABELS,
530 typename VISITOR = vigra::rf3::RFStopVisiting,
531 typename RANDENGINE = vigra::MersenneTwister>
532 vigra::rf3::RandomForest<FEATURES, LABELS>
533 random_forest(
534 FEATURES const & features,
535 LABELS const & labels,
536 vigra::rf3::RandomForestOptions const & options = vigra::rf3::RandomForestOptions(),
537 VISITOR visitor = vigra::rf3::RFStopVisiting(),
538 RANDENGINE & randengine = vigra::MersenneTwister::global()
539 );
540 }}
541 \endcode
542
543 <b> Usage:</b>
544
545 <b>\#include</b> <vigra/random_forest_3.hxx><br>
546 Namespace: vigra::rf3
547
548 \code
549 using namespace vigra;
550
551 int num_instances = ...;
552 int num_features = ...;
553 MultiArray<2, double> train_features(Shape2(num_instances, num_features));
554 MultiArray<1, int> train_labels(Shape1(num_instances));
555 ... // fill training data matrices
556
557 rf3::OOBError oob; // visitor to compute the out-of-bag error
558 auto rf = random_forest(train_features, train_labels,
559 rf3::RandomForestOptions().tree_count(100)
560 .features_per_node(rf3::RF_SQRT)
561 .n_threads(4)
562 rf3::create_visitor(oob));
563
564 std::cout << "Random forest training finished with out-of-bag error " << oob.oob_err_ << "\n";
565
566 int num_test_instances = ...;
567 MultiArray<2, double> test_features(Shape2(num_test_instances, num_features));
568 MultiArray<1, int> test_labels(Shape1(num_test_instances));
569 ... // fill feature matrix for test data
570
571 rf.predict(test_features, test_labels);
572
573 for(int i=0; i<num_test_instances; ++i)
574 std::cerr << "Prediction for test instance " << i << ": " << test_labels(i) << "\n";
575 \endcode
576*/
577doxygen_overloaded_function(template <...> void random_forest)
578
579template <typename FEATURES, typename LABELS, typename VISITOR, typename RANDENGINE>
580inline
583 FEATURES const & features,
584 LABELS const & labels,
585 RandomForestOptions const & options,
588){
589 typedef detail::GeneralScorer<GiniScore> GiniScorer;
590 typedef detail::GeneralScorer<EntropyScore> EntropyScorer;
591 typedef detail::GeneralScorer<KolmogorovSmirnovScore> KSDScorer;
592 if (options.split_ == RF_GINI)
593 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, GiniScorer, RANDENGINE>(features, labels, options, visitor, randengine);
594 else if (options.split_ == RF_ENTROPY)
595 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, EntropyScorer, RANDENGINE>(features, labels, options, visitor, randengine);
596 else if (options.split_ == RF_KSD)
597 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, KSDScorer, RANDENGINE>(features, labels, options, visitor, randengine);
598 else
599 throw std::runtime_error("random_forest(): Unknown split criterion.");
600}
601
602template <typename FEATURES, typename LABELS, typename VISITOR>
603inline
606 FEATURES const & features,
607 LABELS const & labels,
608 RandomForestOptions const & options,
610){
611 auto randengine = MersenneTwister::global();
612 return random_forest(features, labels, options, visitor, randengine);
613}
614
615template <typename FEATURES, typename LABELS>
616inline
617RandomForest<FEATURES, LABELS>
619 FEATURES const & features,
620 LABELS const & labels,
621 RandomForestOptions const & options
622){
623 RFStopVisiting stop;
624 return random_forest(features, labels, options, stop);
625}
626
627template <typename FEATURES, typename LABELS>
628inline
629RandomForest<FEATURES, LABELS>
631 FEATURES const & features,
632 LABELS const & labels
633){
634 return random_forest(features, labels, RandomForestOptions());
635}
636
637} // namespace rf3
638
639//@}
640
641} // namespace vigra
642
643#endif
Class for a single RGB value.
Definition rgbvalue.hxx:128
Options object for the random forest.
Definition rf_common.hxx:171
Options object for the Sampler class.
Definition sampling.hxx:64
Thread pool class to manage a set of parallel workers.
Definition threadpool.hxx:148
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Random forest 'maximum depth' stop criterion.
Definition random_forest_common.hxx:477
Random forest 'node complexity' stop criterion.
Definition random_forest_common.hxx:525
Random forest 'number of datapoints' stop criterion.
Definition random_forest_common.hxx:500
Random forest 'node purity' stop criterion.
Definition random_forest_common.hxx:464
Options class for vigra::rf3::RandomForest version 3.
Definition random_forest_common.hxx:583
size_t get_features_per_node(size_t total) const
Get the actual number of features per node.
Definition random_forest_common.hxx:772
Random forest version 3.
Definition random_forest.hxx:69
void random_forest(...)
Train a vigra::rf3::RandomForest classifier.
void indexSort(Iterator first, Iterator last, IndexIterator index_first, Compare c)
Return the index permutation that would sort the input array.
Definition algorithm.hxx:414
void applyPermutation(IndexIterator index_first, IndexIterator index_last, InIterator in, OutIterator out)
Sort an array according to the given index permutation.
Definition algorithm.hxx:456

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.1