35 like_ptr LogLike = get_purpose(get_inifile_value<std::string>(
"like",
"LogLike"));
40 LogLike->disable_external_shutdown();
42 int dim = get_dimension();
45 MPI_Comm_size(MPI_COMM_WORLD, &numtasks);
51 txt_options.
setValue(
"synchronised",
false);
52 get_printer().new_stream(
"txt", txt_options);
53 set_resume_params.set_resume_mode(get_printer().resume_mode());
55 int pdim = get_inifile_value<int>(
"projection_dimension", 4);
56 TWalk(LogLike, get_printer(),
59 get_inifile_value<double>(
"kwalk_ratio", 0.9836),
61 get_inifile_value<double>(
"gaussian_distance", 2.4),
62 get_inifile_value<double>(
"walk_distance", 2.5),
63 get_inifile_value<double>(
"traverse_distance", 6.0),
64 get_inifile_value<long long>(
"ran_seed", 0),
65 get_inifile_value<double>(
"sqrtR", 1.001),
66 get_inifile_value<int>(
"chain_number", 1 + pdim + numtasks),
67 get_inifile_value<bool>(
"hyper_grid",
true),
68 get_inifile_value<int>(
"burn_in", 0),
69 get_inifile_value<int>(
"save_freq", 1000),
70 get_inifile_value<double>(
"timeout_mins", -1)
87 unsigned long long int id;
99 const long long &rand,
102 const bool &hyper_grid,
104 const int &save_freq,
105 const double &mins_max)
108 const double massiveR = 1e100;
110 std::vector<double> chisq(NChains);
111 std::vector<double> aNext(dimension);
112 std::vector<std::vector<double>> a0(NChains, std::vector<double>(dimension));
113 double ans, chisqnext;
114 std::vector<int> mult(NChains, 1);
115 std::vector<int> totN(NChains, 0);
116 std::vector<int> count(NChains, 1);
118 int total = 1, ttotal = 0;
120 std::vector<std::vector<double>> covT(NChains, std::vector<double>(dimension, 0.0));
121 std::vector<std::vector<double>> avgT(NChains, std::vector<double>(dimension, 0.0));
122 std::vector<double> W(dimension, 0.0);
123 std::vector<double> avgTot(dimension, 0.0);
124 bool converged =
false;
125 std::vector<unsigned long long int>
ids(NChains);
126 std::vector<int> ranks(NChains);
127 unsigned long long int next_id;
128 double Rsum = massiveR, Rmax = massiveR;
129 bool resumed =
false;
131 unsigned int quit = 0;
133 std::chrono::time_point<std::chrono::system_clock> startTWalk;
135 set_resume_params(chisq, a0, mult, totN, count, total, ttotal, covT, avgT, W, avgTot, ids, ranks, resumed);
139 int rank = set_resume_params.
Rank();
140 int numtasks = set_resume_params.
NumTasks();
143 std::vector<int> tints(NChains);
144 for (
int i = 0; i < NChains; i++) tints[i] = i;
145 std::vector<int> talls(2*numtasks);
146 set_resume_params(tints, talls);
149 std::ofstream temp_file_out;
151 if (mins_max > 0 and rank == 0)
154 startTWalk = std::chrono::system_clock::now();
157 std::vector<RanNumGen *> gDev;
158 for (
int i = 0; i < NChains; i++)
160 gDev.push_back(
new RanNumGen(proj, dimension, din, alim, alimt, div, rand));
165 temp_file_out.open(filename, std::ofstream::binary | std::ofstream::app);
166 if (not temp_file_out.is_open())
scan_error().raise(
LOCAL_INFO,
"Problem opening temp file " + filename +
" in TWalk!");
171 for (
int i = 0; i < numtasks; i++)
173 MPI_Barrier(MPI_COMM_WORLD);
174 MPI_Bcast (
c_ptr(a0[talls[i]]), a0[talls[i]].size(), MPI_DOUBLE, i, MPI_COMM_WORLD);
175 MPI_Bcast (&chisq[talls[i]], 1, MPI_DOUBLE, i, MPI_COMM_WORLD);
176 MPI_Bcast (&mult[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
177 MPI_Bcast (&count[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
178 MPI_Bcast (&ranks[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
179 MPI_Bcast (&ids[talls[i]], 1, MPI_UNSIGNED_LONG_LONG, i, MPI_COMM_WORLD);
186 for (t = 0; t < NChains; t++)
190 for (
int j = 0; j < dimension; j++)
191 a0[t][j] = (gDev[t]->Doub());
192 chisq[t] = -LogLike(a0[t]);
194 ids[t] = LogLike->getPtID();
198 MPI_Barrier(MPI_COMM_WORLD);
199 MPI_Bcast (
c_ptr(a0[t]), a0[t].size(), MPI_DOUBLE, 0, MPI_COMM_WORLD);
200 MPI_Bcast (&quit, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
206 <<
"Rank "<<rank<<
": " 208 <<
"Quit signal received during TWalk chain initialisation, aborting run" << std::endl;
215 MPI_Barrier(MPI_COMM_WORLD);
216 MPI_Bcast (
c_ptr(chisq), chisq.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD);
217 MPI_Bcast (
c_ptr(ids), ids.size(), MPI_UNSIGNED_LONG_LONG, 0, MPI_COMM_WORLD);
218 MPI_Bcast (
c_ptr(ranks), ranks.size(), MPI_INT, 0, MPI_COMM_WORLD);
221 std::cout <<
"Metropolis Hastings/TWalk Algorithm Started" << std::endl;
223 while (not converged and not quit)
229 for(
int i = 0; i < numtasks; i++)
231 int temp =
int((j--)*gDev[0]->Doub());
232 talls[i] = tints[temp];
233 tints[temp] = tints[j];
237 for(
int i = numtasks, end = talls.size(); i < end; i++)
239 talls[i] = tints[
int(j*gDev[0]->Doub())];
243 MPI_Barrier(MPI_COMM_WORLD);
244 MPI_Bcast (
c_ptr(talls), talls.size(), MPI_INT, 0, MPI_COMM_WORLD);
245 MPI_Bcast (
c_ptr(tints), tints.size(), MPI_INT, 0, MPI_COMM_WORLD);
248 tt = talls[rank + numtasks];
249 double logZ = gDev[t]->Dev(aNext, a0, t, tt, NChains - numtasks, tints);
251 t =
int(NChains*gDev[0]->Doub());
252 tt =
int((NChains - 1)*gDev[0]->Doub());
254 double logZ = gDev[t]->Dev(aNext, a0, t, tt);
257 if(!(hyper_grid &&
notUnit(aNext)))
259 chisqnext = -LogLike(aNext);
260 ans = chisqnext - chisq[t] - logZ;
261 next_id = LogLike->getPtID();
262 if ((ans <= 0.0)||(gDev[0]->ExpDev() >= ans))
266 point_info
info = {mult[t], t, ranks[t], ids[t]};
267 temp_file_out.write((
char *)&info,
sizeof(point_info));
271 chisq[t] = chisqnext;
280 point_info
info = {0, -1,
rank, next_id};
281 temp_file_out.write((
char *)&info,
sizeof(point_info));
286 for (
int i = 0; i < numtasks; i++)
288 MPI_Barrier(MPI_COMM_WORLD);
289 MPI_Bcast (
c_ptr(a0[talls[i]]), a0[talls[i]].size(), MPI_DOUBLE, i, MPI_COMM_WORLD);
290 MPI_Bcast (&chisq[talls[i]], 1, MPI_DOUBLE, i, MPI_COMM_WORLD);
291 MPI_Bcast (&mult[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
292 MPI_Bcast (&count[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
293 MPI_Bcast (&ranks[talls[i]], 1, MPI_INT, i, MPI_COMM_WORLD);
294 MPI_Bcast (&ids[talls[i]], 1, MPI_UNSIGNED_LONG_LONG, i, MPI_COMM_WORLD);
298 for (
int l = 0; l < NChains; l++) mult[l]++;
302 if (total%save_freq == 0)
304 set_resume_params.
dump();
311 for (
auto it = count.begin(); it != count.end(); ++it)
316 if (total%NChains == 0 && cnt >= burn_in*NChains)
318 for (
int i = 0; i < NChains; i++)
for (
int j = 0; j < dimension; j++)
322 covT[i][j] = avgT[i][j] = avgTot[j] = W[j] = 0.0;
326 double davg = (a0[i][j]-avgT[i][j])/(ttotal+1.0);
327 double dcov = ttotal*davg*davg - covT[i][j]/(ttotal+1.0);
328 avgTot[j] += davg/NChains;
331 W[j] += dcov/NChains;
340 for (
int i = 0; i < dimension; i++)
343 for (
int ts = 0; ts < NChains; ts++)
345 Bn += (avgT[ts][i] - avgTot[i])*(avgT[ts][i] - avgTot[i]);
347 Bn /=
double(NChains - 1);
349 double R = W[i] > 0.0 ? 1.0 +
double(NChains + 1)*Bn / (W[i] *
double(NChains)) : massiveR;
352 Rmax = std::max(Rmax, R);
355 if (R >= sqrtR*sqrtR) converged =
false;
362 std::chrono::duration<double> runtime = std::chrono::system_clock::now() - startTWalk;
363 double runtime_ms = std::chrono::duration_cast<std::chrono::milliseconds>(runtime).count();
364 if (runtime_ms / 60e3 >= mins_max)
366 std::cout <<
"TWalk reached requested time limit of " << mins_max <<
" minutes. Finalising run now." << std::endl;
372 if (converged or cnt % 100 == 0)
374 std::cout <<
"Points = " << cnt <<
" (" << cnt/
double(NChains) <<
" per chain)" << std::endl;
375 std::cout <<
"\tAcceptance ratio = " << (
double)cnt/(
double)total/(
double)numtasks << std::endl;
376 std::cout <<
"\tsqrt(R) (averaged over all dimensions) = " << sqrt(Rsum/dimension) << std::endl;
377 std::cout <<
"\tsqrt(R) (largest in any dimension) = " << sqrt(Rmax) << std::endl;
384 MPI_Barrier(MPI_COMM_WORLD);
385 MPI_Bcast (&converged, 1, MPI_C_BOOL, 0, MPI_COMM_WORLD);
386 MPI_Bcast (&quit, 1, MPI_UNSIGNED, 0, MPI_COMM_WORLD);
392 <<
"Rank "<<rank<<
": " 394 <<
"TWalk received quit signal! Terminating run." << std::endl;
402 <<
"Rank "<<rank<<
": " 404 <<
"Writing resume data for TWalk" << std::endl;
417 set_resume_params.
dump();
423 for (
auto &&gd : gDev)
delete gd;
425 temp_file_out.close();
427 std::ifstream temp_file_in(set_resume_params.
get_temp_file_name(
"temp").c_str(), std::ifstream::binary);
430 while (temp_file_in.read((
char *)&info,
sizeof(point_info)))
433 out_stream->
print(info.mult,
"mult", info.rank, info.id);
434 out_stream->
print(info.chain,
"chain", info.rank, info.id);
438 std::cout <<
"TWalk has finished in process " << rank <<
"." << std::endl;
void print(T const &in, const std::string &label, const int vertexID, const uint rank, const ulong pointID)
DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry DecayTable::Entry double
void TWalk(Gambit::Scanner::like_ptr LogLike, Gambit::Scanner::printer_interface &printer, Gambit::Scanner::resume_params_func set_resume_params, const int &dimension, const double &div, const int &proj, const double &din, const double &alim, const double &alimt, const long long &rand, const double &sqrtR, const int &NChains, const bool &hyper_grid, const int &burn_in, const int &save_freq, const double &hrs_max)
Interface details for scanner plugins.
std::string get_temp_file_name(const std::string &temp_file)
scanner_plugin(twalk, version(1, 0, 1))
Class to interface with the plugin manager's resume functions.
bool notUnit(const std::vector< double > &in)
Base function objects for use in GAMBIT.
#define plugin_main(...)
Declaration of the main function, which will be run by the interface.
bool early_shutdown_in_progress() const
Simple header file for turning compiler warnings back on after having included one of the begin_ignor...
std::string str
Shorthand for a standard string.
T::iterator::pointer c_ptr(T &it)
virtual BaseBasePrinter * get_stream(const std::string &="")=0
Getter for auxiliary printer objects.
Pragma directives to suppress compiler warnings coming from including MPI library headers...
EXPORT_SYMBOLS error & scan_error()
Scanner errors.
Manager class for creating printer objects.
void assign_aux_numbers()
Declaration for the scanner module.
void setValue(const KEYTYPE &key, const VALTYPE &val)
Basic setter for adding extra options.
EXPORT_SYMBOLS pluginInfo plugin_info
Access Functor for plugin info.
virtual void flush()=0
Signal the printer to flush data in buffers to disk. Printers should do this automatically as needed...
TODO: see if we can use this one:
A small wrapper object for 'options' nodes.
Likelihood container for scanner plugins.