SLAMflex SE  0.1.0
SLAMflex provides detection and tracking of dominant planes for smartphone devices. The detected plane can then be used to show AR content relative to the plane's orientation. The plane is detected in the field of view of the smartphone camera and then tracked in subsequent frames. The interface returns the plane's position and orientation.
conjugate_gradient.h
Go to the documentation of this file.
1 #include "optimization/brent.h"
2 #include <utility>
3 #include <cmath>
4 #include <cassert>
5 #include <cstdlib>
6 
7 namespace TooN{
8  namespace Internal{
9 
10 
17  template<int Size, typename Precision, typename Func> struct LineSearch
18  {
21 
22  const Func& f;
23 
28  LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
29  :start(s),direction(d),f(func)
30  {}
31 
34  Precision operator()(Precision x) const
35  {
36  return f(start + x * direction);
37  }
38  };
39 
51  template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
52  {
53  //Get a, b, c to bracket a minimum along a line
54  Precision a, b, c, b_val, c_val;
55 
56  a=0;
57 
58  //Search forward in steps of lambda
59  Precision lambda=initial_lambda;
60  b = lambda;
61  b_val = func(b);
62 
63  while(std::isnan(b_val))
64  {
65  //We've probably gone in to an invalid region. This can happen even
66  //if following the gradient would never get us there.
67  //try backing off lambda
68  lambda*=.5;
69  b = lambda;
70  b_val = func(b);
71 
72  }
73 
74 
75  if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
76  {
77  double last_good_lambda = lambda;
78 
79  for(;;)
80  {
81  lambda *= 2;
82  c = lambda;
83  c_val = func(c);
84 
85  if(std::isnan(c_val))
86  break;
87  last_good_lambda = lambda;
88  if(c_val > b_val) // we have a bracket
89  break;
90  else
91  {
92  a = b;
93  a_val = b_val;
94  b=c;
95  b_val=c_val;
96 
97  }
98  }
99 
100  //We took a step too far.
101  //Back up: this will not attempt to ensure a bracket
102  if(std::isnan(c_val))
103  {
104  double bad_lambda=lambda;
105  double l=1;
106 
107  for(;;)
108  {
109  l*=.5;
110  c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
111  c_val = func(c);
112 
113  if(!std::isnan(c_val))
114  break;
115  }
116 
117 
118  }
119 
120  }
121  else //We've overshot the minimum, so back up
122  {
123  c = b;
124  c_val = b_val;
125  //Here, c_val > a_val
126 
127  for(;;)
128  {
129  lambda *= .5;
130  b = lambda;
131  b_val = func(b);
132 
133  if(b_val < a_val)// we have a bracket
134  break;
135  else if(lambda < zeps)
136  return Zeros;
137  else //Contract the bracket
138  {
139  c = b;
140  c_val = b_val;
141  }
142  }
143  }
144 
145  Matrix<3,2> ret;
146  ret[0] = makeVector(a, a_val);
147  ret[1] = makeVector(b, b_val);
148  ret[2] = makeVector(c, c_val);
149 
150  return ret;
151  }
152 
153 }
154 
155 
200 template<int Size, class Precision=double> struct ConjugateGradient
201 {
202  const int size;
210  Precision y;
211  Precision old_y;
212 
213  Precision tolerance;
214  Precision epsilon;
216 
219  Precision linesearch_epsilon;
221 
222  Precision bracket_epsilon;
223 
225 
230  template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
231  : size(start.size()),
232  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
233  {
234  init(start, func(start), deriv(start));
235  }
236 
241  template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
242  : size(start.size()),
243  g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
244  {
245  init(start, func(start), deriv);
246  }
247 
252  void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
253  {
254 
255  using std::numeric_limits;
256  x = start;
257 
258  //Start with the conjugate direction aligned with
259  //the gradient
260  g = deriv;
261  h = g;
262  minus_h=-h;
263 
264  y = func;
265  old_y = y;
266 
267  tolerance = sqrt(numeric_limits<Precision>::epsilon());
268  epsilon = 1e-20;
269  max_iterations = size * 100;
270 
271  bracket_initial_lambda = 1;
272 
273  linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
274  linesearch_epsilon = 1e-20;
275  linesearch_max_iterations=100;
276 
277  bracket_epsilon=1e-20;
278 
279  iterations=0;
280  }
281 
282 
296  template<class Func> void find_next_point(const Func& func)
297  {
298  Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);
299 
300  //Always search in the conjugate direction (h)
301  //First bracket a minimum.
302  Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);
303 
304  double a = bracket[0][0];
305  double b = bracket[1][0];
306  double c = bracket[2][0];
307 
308  double a_val = bracket[0][1];
309  double b_val = bracket[1][1];
310  double c_val = bracket[2][1];
311 
312  old_y = y;
313  old_x = x;
314  iterations++;
315 
316  //Local maximum achieved!
317  if(a==0 && b== 0 && c == 0)
318  return;
319 
320  //We should have a bracket here
321 
322  if(c < b)
323  {
324  //Failed to bracket due to NaN, so c is the best known point.
325  //Simply go there.
326  x-=h * c;
327  y=c_val;
328 
329  }
330  else
331  {
332  assert(a < b && b < c);
333  assert(a_val > b_val && b_val < c_val);
334 
335  //Find the real minimum
336  Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);
337 
338  assert(m[0] >= a && m[0] <= c);
339  assert(m[1] <= b_val);
340 
341  //Update the current position and value
342  x -= m[0] * h;
343  y = m[1];
344  }
345  }
346 
349  bool finished()
350  {
351  using std::abs;
352  return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
353  }
354 
363  void update_vectors_PR(const Vector<Size>& grad)
364  {
365  //Update the position, gradient and conjugate directions
366  old_g = g;
367  old_h = h;
368 
369  g = grad;
370  //Precision gamma = (g * g - oldg*g)/(oldg * oldg);
371  Precision gamma = (g * g - old_g*g)/(old_g * old_g);
372  h = g + gamma * old_h;
373  minus_h=-h;
374  }
375 
393  template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
394  {
395  find_next_point(func);
396 
397  if(!finished())
398  {
399  update_vectors_PR(deriv(x));
400  return 1;
401  }
402  else
403  return 0;
404  }
405 };
406 
407 }
const int size
Dimensionality of the space.
Precision bracket_initial_lambda
Initial stepsize used in bracketing the minimum for the line search. Defaults to 1.
int iterations
Number of iterations performed.
Precision operator()(Precision x) const
Vector< Size > g
Gradient vector used by the next call to iterate()
bool iterate(const Func &func, const Deriv &deriv)
Everything lives inside this namespace.
Definition: allocator.hh:48
Precision old_y
Function at old_x.
Vector< 2, Precision > brent_line_search(Precision a, Precision x, Precision b, Precision fx, const Functor &func, int maxiterations, Precision tolerance=sqrt(numeric_limits< Precision >::epsilon()), Precision epsilon=numeric_limits< Precision >::epsilon())
Definition: brent.h:29
Precision tolerance
Tolerance used to determine if the optimization is complete. Defaults to the square root of machine precision.
void update_vectors_PR(const Vector< Size > &grad)
void init(const Vector< Size > &start, const Precision &func, const Vector< Size > &deriv)
Precision linesearch_tolerance
Tolerance used to determine if the line search is complete. Defaults to the square root of machine precision.
T abs(T t)
Definition: abs.h:30
Matrix< 3, 2, Precision > bracket_minimum_forward(Precision a_val, const Func &func, Precision initial_lambda, Precision zeps)
Vector< Size > old_g
Gradient vector used to compute $h$ in the last call to iterate()
Vector< Size > x
Current position (best known point)
Precision epsilon
Additive term in the tolerance that prevents excessive iterations if $y_\mathrm{optimal} = 0$. Known as ZEPS in Numerical Recipes. Defaults to 1e-20.
Vector< Size > old_x
Previous best known point (not set at construction)
void find_next_point(const Func &func)
ConjugateGradient(const Vector< Size > &start, const Func &func, const Vector< Size > &deriv)
Vector< Size > minus_h
Negative of h. It is needed because it is passed to a function that stores references, so it cannot be a temporary.
Precision bracket_epsilon
Minimum step size for the initial bracketing of the minimum. Below this, the system is assumed to have converged. Defaults to 1e-20.
Vector< 1 > makeVector(double x1)
Definition: make_vector.hh:4
Vector< Size > h
Conjugate vector to be searched along in the next call to iterate()
Precision y
Function value at $x$.
ConjugateGradient(const Vector< Size > &start, const Func &func, const Deriv &deriv)
int max_iterations
Maximum number of iterations. Defaults to $\mathrm{size} \times 100$.
Precision linesearch_epsilon
Additive term in the line-search tolerance that prevents excessive iterations if $x_\mathrm{optimal} = 0$. Known as ZEPS in Numerical Recipes. Defaults to 1e-20.
int linesearch_max_iterations
Maximum number of iterations in the linesearch. Defaults to 100.
Vector< Size > old_h
Conjugate vector searched along in the last call to iterate()
const Vector< Size, Precision > & direction
LineSearch(const Vector< Size, Precision > &s, const Vector< Size, Precision > &d, const Func &func)
const Vector< Size, Precision > & start
static Operator< Internal::Zero > Zeros
Definition: objects.h:727
bool isnan(const Vector< S, P, B > &v)
Definition: helpers.h:308