GAMBIT is hosted by Hepforge, IPPP Durham.
GAMBIT  v1.5.0-2191-ga4742ac
a Global And Modular Bsm Inference Tool
twalk.cpp
Go to the documentation of this file.
1 // GAMBIT: Global and Modular BSM Inference Tool
2 // *********************************************
20 
21 #ifdef WITH_MPI
23 #include "mpi.h"
25 #endif
26 
27 #include "plugin_interface.hpp"
28 #include "scanner_plugin.hpp"
29 #include "twalk.hpp"
30 
31 scanner_plugin(twalk, version(1, 0, 1))
32 {
33  int plugin_main ()
34  {
35  like_ptr LogLike = get_purpose(get_inifile_value<std::string>("like", "LogLike"));
36 
37  // Do not allow GAMBIT's own likelihood calculator to directly shut down the scan.
38  // Twalk will assume responsibility for this process, triggered externally by
39  // the 'plugin_info.early_shutdown_in_progress()' function.
40  LogLike->disable_external_shutdown();
41 
42  int dim = get_dimension();
43  int numtasks;
44  #ifdef WITH_MPI
45  MPI_Comm_size(MPI_COMM_WORLD, &numtasks);
46  #else
47  numtasks = 1;
48  #endif
49 
50  Gambit::Options txt_options;
51  txt_options.setValue("synchronised",false);
52  get_printer().new_stream("txt", txt_options);
53  set_resume_params.set_resume_mode(get_printer().resume_mode());
54 
55  int pdim = get_inifile_value<int>("projection_dimension", 4);
56  TWalk(LogLike, get_printer(),
57  set_resume_params,
58  dim,
59  get_inifile_value<double>("kwalk_ratio", 0.9836),
60  pdim,
61  get_inifile_value<double>("gaussian_distance", 2.4),
62  get_inifile_value<double>("walk_distance", 2.5),
63  get_inifile_value<double>("traverse_distance", 6.0),
64  get_inifile_value<long long>("ran_seed", 0),
65  get_inifile_value<double>("sqrtR", 1.001),
66  get_inifile_value<int>("chain_number", 1 + pdim + numtasks),
67  get_inifile_value<bool>("hyper_grid", true),
68  get_inifile_value<int>("burn_in", 0),
69  get_inifile_value<int>("save_freq", 1000),
70  get_inifile_value<double>("timeout_mins", -1)
71  );
72 
73  return 0;
74  }
75 }
76 
77 
78 namespace Gambit
79 {
80  namespace Scanner
81  {
/// Record describing one proposed or accepted TWalk point, as stored in the
/// binary temp file. TWalk writes and reads it back as raw bytes
/// (sizeof(point_info)), and builds it by positional aggregate initialisation.
/// Field order and types must therefore stay exactly as they are.
struct point_info
{
  int mult;                // multiplicity written for the point (0 for a rejected proposal)
  int chain;               // chain index of the point (-1 for a rejected proposal)
  int rank;                // MPI rank that produced the point
  unsigned long long id;   // point ID as returned by LogLike->getPtID()
};
89 
92  Gambit::Scanner::resume_params_func set_resume_params,
93  const int &dimension,
94  const double &div,
95  const int &proj,
96  const double &din,
97  const double &alim,
98  const double &alimt,
99  const long long &rand,
100  const double &sqrtR,
101  const int &NChains,
102  const bool &hyper_grid,
103  const int &burn_in,
104  const int &save_freq,
105  const double &mins_max)
106  {
107 
108  const double massiveR = 1e100;
109 
110  std::vector<double> chisq(NChains);
111  std::vector<double> aNext(dimension);
112  std::vector<std::vector<double>> a0(NChains, std::vector<double>(dimension));
113  double ans, chisqnext;
114  std::vector<int> mult(NChains, 1);
115  std::vector<int> totN(NChains, 0);
116  std::vector<int> count(NChains, 1);
117  int t, tt;
118  int total = 1, ttotal = 0;
119 
120  std::vector<std::vector<double>> covT(NChains, std::vector<double>(dimension, 0.0));
121  std::vector<std::vector<double>> avgT(NChains, std::vector<double>(dimension, 0.0));
122  std::vector<double> W(dimension, 0.0);
123  std::vector<double> avgTot(dimension, 0.0);
124  bool converged = false;
125  std::vector<unsigned long long int> ids(NChains);
126  std::vector<int> ranks(NChains);
127  unsigned long long int next_id;
128  double Rsum = massiveR, Rmax = massiveR;
129  bool resumed = false;
130 
131  unsigned int quit = 0; // signal for early shutdown
132 
133  std::chrono::time_point<std::chrono::system_clock> startTWalk;
134 
135  set_resume_params(chisq, a0, mult, totN, count, total, ttotal, covT, avgT, W, avgTot, ids, ranks, resumed);
136 
137  Gambit::Scanner::assign_aux_numbers("mult", "chain");
138 
139  int rank = set_resume_params.Rank();
140  int numtasks = set_resume_params.NumTasks();
141 
142  #ifdef WITH_MPI
143  std::vector<int> tints(NChains);
144  for (int i = 0; i < NChains; i++) tints[i] = i;
145  std::vector<int> talls(2*numtasks);
146  set_resume_params(tints, talls);
147  #endif
148 
149  std::ofstream temp_file_out;
150 
151  if (mins_max > 0 and rank == 0)
152  {
153  // Begin timing of TWalk run
154  startTWalk = std::chrono::system_clock::now();
155  }
156 
157  std::vector<RanNumGen *> gDev;
158  for (int i = 0; i < NChains; i++)
159  {
160  gDev.push_back(new RanNumGen(proj, dimension, din, alim, alimt, div, rand));
161  }
162 
163  // Try opening the temporary file for saving the mutliplicities etc.
164  str filename = set_resume_params.get_temp_file_name("temp");
165  temp_file_out.open(filename, std::ofstream::binary | std::ofstream::app);
166  if (not temp_file_out.is_open()) scan_error().raise(LOCAL_INFO, "Problem opening temp file " + filename + " in TWalk!");
167 
168  if (resumed)
169  {
170  #ifdef WITH_MPI
171  for (int i = 0; i < numtasks; i++)
172  {
173  MPI_Barrier(MPI_COMM_WORLD);
174  MPI_Bcast (c_ptr(a0[talls[i]]), a0[talls[i]].size(), MPI_DOUBLE, i, MPI_COMM_WORLD);
175  MPI_Bcast (&chisq[talls[i]], 1, MPI_DOUBLE, i, MPI_COMM_WORLD);
176  MPI_Bcast (&mult[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
177  MPI_Bcast (&count[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
178  MPI_Bcast (&ranks[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
179  MPI_Bcast (&ids[talls[i]], 1, MPI_UNSIGNED_LONG_LONG, i, MPI_COMM_WORLD);
180  }
181  #endif
182  }
183  else
184  {
185  resumed = true;
186  for (t = 0; t < NChains; t++)
187  {
188  if (rank == 0)
189  {
190  for (int j = 0; j < dimension; j++)
191  a0[t][j] = (gDev[t]->Doub());
192  chisq[t] = -LogLike(a0[t]);
194  ids[t] = LogLike->getPtID();
195  ranks[t] = rank;
196  }
197  #ifdef WITH_MPI
198  MPI_Barrier(MPI_COMM_WORLD);
199  MPI_Bcast (c_ptr(a0[t]), a0[t].size(), MPI_DOUBLE, 0, MPI_COMM_WORLD);
200  MPI_Bcast (&quit, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
201  #endif
202  if(quit)
203  {
204  std::cout
205  #ifdef WITH_MPI
206  <<"Rank "<<rank<<": "
207  #endif
208  <<"Quit signal received during TWalk chain initialisation, aborting run" << std::endl;
209  break;
210  }
211  }
212  }
213 
214  #ifdef WITH_MPI
215  MPI_Barrier(MPI_COMM_WORLD);
216  MPI_Bcast (c_ptr(chisq), chisq.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD);
217  MPI_Bcast (c_ptr(ids), ids.size(), MPI_UNSIGNED_LONG_LONG, 0, MPI_COMM_WORLD);
218  MPI_Bcast (c_ptr(ranks), ranks.size(), MPI_INT, 0, MPI_COMM_WORLD);
219  #endif
220 
221  std::cout << "Metropolis Hastings/TWalk Algorithm Started" << std::endl;
222 
223  while (not converged and not quit)
224  {
225  #ifdef WITH_MPI
226  if (rank == 0)
227  {
228  int j = NChains;
229  for(int i = 0; i < numtasks; i++)
230  {
231  int temp = int((j--)*gDev[0]->Doub());
232  talls[i] = tints[temp];
233  tints[temp] = tints[j];
234  tints[j] = talls[i];
235  }
236 
237  for(int i = numtasks, end = talls.size(); i < end; i++)
238  {
239  talls[i] = tints[int(j*gDev[0]->Doub())];
240  }
241  }
242 
243  MPI_Barrier(MPI_COMM_WORLD);
244  MPI_Bcast (c_ptr(talls), talls.size(), MPI_INT, 0, MPI_COMM_WORLD);
245  MPI_Bcast (c_ptr(tints), tints.size(), MPI_INT, 0, MPI_COMM_WORLD);
246 
247  t = talls[rank];
248  tt = talls[rank + numtasks];
249  double logZ = gDev[t]->Dev(aNext, a0, t, tt, NChains - numtasks, tints);
250  #else
251  t = int(NChains*gDev[0]->Doub());
252  tt = int((NChains - 1)*gDev[0]->Doub());
253  if (tt >= t) tt++;
254  double logZ = gDev[t]->Dev(aNext, a0, t, tt);
255  #endif
256 
257  if(!(hyper_grid && notUnit(aNext)))
258  {
259  chisqnext = -LogLike(aNext);
260  ans = chisqnext - chisq[t] - logZ;
261  next_id = LogLike->getPtID();
262  if ((ans <= 0.0)||(gDev[0]->ExpDev() >= ans))
263  {
264  //out_stream->print(mult[t], "mult", ranks[t], ids[t]);
265  //out_stream->print(t, "chain", ranks[t], ids[t]);
266  point_info info = {mult[t], t, ranks[t], ids[t]};
267  temp_file_out.write((char *)&info, sizeof(point_info));
268 
269  ids[t] = next_id;
270  a0[t] = aNext;
271  chisq[t] = chisqnext;
272  ranks[t] = rank;
273  mult[t] = 0;
274  count[t]++;
275  }
276  else
277  {
278  //out_stream->print(0, "mult", rank, next_id);
279  //out_stream->print(-1, "chain", rank, next_id);
280  point_info info = {0, -1, rank, next_id};
281  temp_file_out.write((char *)&info, sizeof(point_info));
282  }
283  }
284 
285  #ifdef WITH_MPI
286  for (int i = 0; i < numtasks; i++)
287  {
288  MPI_Barrier(MPI_COMM_WORLD);
289  MPI_Bcast (c_ptr(a0[talls[i]]), a0[talls[i]].size(), MPI_DOUBLE, i, MPI_COMM_WORLD);
290  MPI_Bcast (&chisq[talls[i]], 1, MPI_DOUBLE, i, MPI_COMM_WORLD);
291  MPI_Bcast (&mult[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
292  MPI_Bcast (&count[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
293  MPI_Bcast (&ranks[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
294  MPI_Bcast (&ids[talls[i]], 1, MPI_UNSIGNED_LONG_LONG, i, MPI_COMM_WORLD);
295  }
296  #endif
297 
298  for (int l = 0; l < NChains; l++) mult[l]++;
299 
300  total++;
301 
302  if (total%save_freq == 0)
303  {
304  set_resume_params.dump();
305  //out_stream->reset();
306  }
307 
308  if (rank == 0)
309  {
310  int cnt = 0;
311  for (auto it = count.begin(); it != count.end(); ++it)
312  {
313  cnt += *it;
314  }
315 
316  if (total%NChains == 0 && cnt >= burn_in*NChains)
317  {
318  for (int i = 0; i < NChains; i++) for (int j = 0; j < dimension; j++)
319  {
320  if (ttotal == 0)
321  {
322  covT[i][j] = avgT[i][j] = avgTot[j] = W[j] = 0.0;
323  }
324  else
325  {
326  double davg = (a0[i][j]-avgT[i][j])/(ttotal+1.0);
327  double dcov = ttotal*davg*davg - covT[i][j]/(ttotal+1.0);
328  avgTot[j] += davg/NChains;
329  covT[i][j] += dcov;
330  avgT[i][j] += davg;
331  W[j] += dcov/NChains;
332  }
333  }
334  ttotal++;
335 
336  // Loop over each dimension in the parameter space, and compute R for each.
337  // Trigger convergence only if R is below the requested threshold in every dimension.
338  Rsum = Rmax = 0.0;
339  converged = true;
340  for (int i = 0; i < dimension; i++)
341  {
342  double Bn = 0;
343  for (int ts = 0; ts < NChains; ts++)
344  {
345  Bn += (avgT[ts][i] - avgTot[i])*(avgT[ts][i] - avgTot[i]);
346  }
347  Bn /= double(NChains - 1);
348 
349  double R = W[i] > 0.0 ? 1.0 + double(NChains + 1)*Bn / (W[i] * double(NChains)) : massiveR;
350 
351  Rsum += R;
352  Rmax = std::max(Rmax, R);
353 
354  if (R < 1.0) scan_error().raise(LOCAL_INFO, "R < 1 in TWalk!");
355  if (R >= sqrtR*sqrtR) converged = false;
356  }
357  }
358 
359  // Check if the requested maximum runtime has been reached.
360  if (mins_max > 0)
361  {
362  std::chrono::duration<double> runtime = std::chrono::system_clock::now() - startTWalk;
363  double runtime_ms = std::chrono::duration_cast<std::chrono::milliseconds>(runtime).count();
364  if (runtime_ms / 60e3 >= mins_max)
365  {
366  std::cout << "TWalk reached requested time limit of " << mins_max << " minutes. Finalising run now." << std::endl;
367  converged = true;
368  }
369  }
370 
371  // Print out progress to stdout
372  if (converged or cnt % 100 == 0)
373  {
374  std::cout << "Points = " << cnt << " (" << cnt/double(NChains) << " per chain)" << std::endl;
375  std::cout << "\tAcceptance ratio = " << (double)cnt/(double)total/(double)numtasks << std::endl;
376  std::cout << "\tsqrt(R) (averaged over all dimensions) = " << sqrt(Rsum/dimension) << std::endl;
377  std::cout << "\tsqrt(R) (largest in any dimension) = " << sqrt(Rmax) << std::endl;
378  }
379 
380  }
381 
383  #ifdef WITH_MPI
384  MPI_Barrier(MPI_COMM_WORLD);
385  MPI_Bcast (&converged, 1, MPI_C_BOOL, 0, MPI_COMM_WORLD);
386  MPI_Bcast (&quit, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
387  #endif
388  if(quit)
389  {
390  std::cout
391  #ifdef WITH_MPI
392  <<"Rank "<<rank<<": "
393  #endif
394  <<"TWalk received quit signal! Terminating run." << std::endl;
395  }
396  }
397 
398  if(quit)
399  {
400  std::cout
401  #ifdef WITH_MPI
402  <<"Rank "<<rank<<": "
403  #endif
404  << "Writing resume data for TWalk" << std::endl;
405  // This is a bit awkward, but if we just call .dump() with no argument then
406  // ScannerBit will write resume data for ALL active plugins (which seems a
407  // weird thing for TWalk to trigger) and also try to finalise the printer
408  // (which will cause a crash when ScannerBit automatically tries to
409  // finalise the printer later).
410  // I would just add this dump stuff to scan.cpp, where the printer finalise
411  // is called, but I think that is AFTER the plugins are destructed, so
412  // I don't think that can work.
413  // Doing this will dump JUST the TWalk resume data, though I had to
414  // add this hacky get_name to the set_resume_params object in order to
415  // get the name by which ScannerBit identifies the TWalk plugin.
416  //Gambit::Scanner::Plugins::plugin_info.dump(set_resume_params.get_name());
417  set_resume_params.dump(); // Better way
418  // This works I think, but it still has problems. In particular,
419  // it looks like you must resume with the same number of processes
420  // that you started the run with, which is kind of crap.
421  }
422 
423  for (auto &&gd : gDev) delete gd;
424 
425  temp_file_out.close();
426  Gambit::Scanner::printer *out_stream = printer.get_stream("txt");
427  std::ifstream temp_file_in(set_resume_params.get_temp_file_name("temp").c_str(), std::ifstream::binary);
428  point_info info;
429  int i = 0;
430  while (temp_file_in.read((char *)&info, sizeof(point_info)))
431  {i++;
432  //std::cout<<"Twalk rank "<<rank<<" printing mult and chain for "<<i<<"th point of posterior chain (rank="<<info.rank<<", pointID="<<info.id<<")"<<std::endl;
433  out_stream->print(info.mult, "mult", info.rank, info.id);
434  out_stream->print(info.chain, "chain", info.rank, info.id);
435  }
436  out_stream->flush();
437 
438  std::cout << "TWalk has finished in process " << rank << "." << std::endl;
439 
440  return;
441  }
442 
443  }
444 
445 }
void print(T const &in, const std::string &label, const int vertexID, const uint rank, const ulong pointID)
DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry double
void TWalk(Gambit::Scanner::like_ptr LogLike, Gambit::Scanner::printer_interface &printer, Gambit::Scanner::resume_params_func set_resume_params, const int &dimension, const double &div, const int &proj, const double &din, const double &alim, const double &alimt, const long long &rand, const double &sqrtR, const int &NChains, const bool &hyper_grid, const int &burn_in, const int &save_freq, const double &hrs_max)
Definition: twalk.cpp:90
Interface details for scanner plugins.
std::string get_temp_file_name(const std::string &temp_file)
Definition: plugin_defs.hpp:84
#define LOCAL_INFO
Definition: local_info.hpp:34
scanner_plugin(twalk, version(1, 0, 1))
Definition: twalk.cpp:31
class to interface with the plugin manager resume functions.
Definition: plugin_defs.hpp:52
bool notUnit(const std::vector< double > &in)
Definition: twalk.hpp:33
Base functions objects for use in GAMBIT.
#define plugin_main(...)
Declaration of the main function which will be run by the interface.
Simple header file for turning compiler warnings back on after having included one of the begin_ignor...
std::string str
Shorthand for a standard string.
Definition: Analysis.hpp:35
T::iterator::pointer c_ptr(T &it)
Definition: twalk.hpp:47
DS5_MSPCTM DS_INTDOF int
virtual BaseBasePrinter * get_stream(const std::string &="")=0
Getter for auxiliary printer objects.
Pragma directives to suppress compiler warnings coming from including MPI library headers...
EXPORT_SYMBOLS error & scan_error()
Scanner errors.
Manager class for creating printer objects.
declaration for scanner module
void setValue(const KEYTYPE &key, const VALTYPE &val)
Basic setter, for adding extra options.
EXPORT_SYMBOLS pluginInfo plugin_info
Access Functor for plugin info.
virtual void flush()=0
Signal printer to flush data in buffers to disk. Printers should do this automatically as needed...
TODO: see if we can use this one:
Definition: Analysis.hpp:33
A small wrapper object for 'options' nodes.
likelihood container for scanner plugins.