20fefc3e6326ec2737d539fd2d3dc0528ea18c10
[trackerpp.git] / src / PredictorWrapper.cpp
1 #include "PredictorWrapper.h"
2 #include <string>
3 #include "Logger.h"
4 #include <iostream>
5 #include <thread>
6
7 namespace py = boost::python;
8
9 using namespace std;
10 using namespace suanzi;
11
12
13 PredictorWrapperWPtr PredictorWrapper::instance;
14
15 const static std::string PREDICTOR_PY_DIR = "./python";
16 const static std::string TAG = "PredictorWrapper";
17
18 static std::string parse_python_exception();
19
20 template <class T>
21 boost::python::list toPythonList(std::vector<T> vector) {
22     typename std::vector<T>::iterator iter;
23     boost::python::list list;
24     for (iter = vector.begin(); iter != vector.end(); ++iter) {
25         list.append(*iter);
26     }
27     return list;
28 }
29
30 PredictorWrapperPtr PredictorWrapper::create(const std::string& fname)
31 {
32     if (instance.lock()){
33         return PredictorWrapperPtr();
34     }
35     PredictorWrapperPtr ins (new PredictorWrapper(fname));
36     instance = ins;
37     return ins;
38 }
39
40 void PredictorWrapper::dump()
41 {
42     LOG_DEBUG(TAG, "dump");
43     std::string ss = "";
44     try{
45         py::object ret = this->dump_func();
46         ss = py::extract<std::string>(ret);
47     } catch (boost::python::error_already_set const &){
48         std::string perror_str = parse_python_exception();
49         LOG_ERROR(TAG, "Error in Python: " + perror_str)
50     }
51     LOG_DEBUG(TAG, ss);
52 }
53
54 double PredictorWrapper::predict(int index, std::vector<double> ff)
55 {
56     LOG_DEBUG(TAG, "predict");
57     try{
58         this->predict_func(index, toPythonList(ff));
59     } catch (boost::python::error_already_set const &){
60         std::string perror_str = parse_python_exception();
61         LOG_ERROR(TAG, "Error in Python: " + perror_str)
62     }
63     return 0.1;
64 }
65
66 PredictorWrapper::PredictorWrapper(const std::string& fname)
67 {
68     Py_Initialize();
69     try{
70         py::object main_module = py::import("__main__");
71         py::object main_namespace = main_module.attr("__dict__");
72         py::exec("import sys", main_namespace);
73         std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')";
74         py::exec(cmd.c_str(), main_namespace);
75         py::exec("import signal", main_namespace);
76         py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace);
77         py::object predictor_mod = py::import("predictor");
78         py::object predictor_init = predictor_mod.attr("init");
79         dump_func = predictor_mod.attr("dump");
80         predict_func = predictor_mod.attr("predict");
81         predictor_init(fname.c_str());
82     } catch (boost::python::error_already_set const &){
83         std::string perror_str = parse_python_exception();
84         LOG_ERROR(TAG, "Error in Python: " + perror_str)
85     }
86 }
87
88 static std::string parse_python_exception(){  
89     PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;  
90     // Fetch the exception info from the Python C API  
91     PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);  
92   
93     // Fallback error  
94     std::string ret("Unfetchable Python error");  
95     // If the fetch got a type pointer, parse the type into the exception string  
96     if(type_ptr != NULL){  
97         py::handle<> h_type(type_ptr);  
98         py::str type_pstr(h_type);  
99         // Extract the string from the boost::python object  
100         py::extract<std::string> e_type_pstr(type_pstr);  
101         // If a valid string extraction is available, use it   
102         //  otherwise use fallback  
103         if(e_type_pstr.check())  
104             ret = e_type_pstr();  
105         else  
106             ret = "Unknown exception type";  
107     }  
108     // Do the same for the exception value (the stringification of the exception)  
109     if(value_ptr != NULL){  
110         py::handle<> h_val(value_ptr);  
111         py::str a(h_val);  
112         py::extract<std::string> returned(a);  
113         if(returned.check())  
114             ret +=  ": " + returned();  
115         else  
116             ret += std::string(": Unparseable Python error: ");  
117     }  
118     // Parse lines from the traceback using the Python traceback module  
119     if(traceback_ptr != NULL){  
120         py::handle<> h_tb(traceback_ptr);  
121         // Load the traceback module and the format_tb function  
122         py::object tb(py::import("traceback"));  
123         py::object fmt_tb(tb.attr("format_tb"));  
124         // Call format_tb to get a list of traceback strings  
125         py::object tb_list(fmt_tb(h_tb));  
126         // Join the traceback strings into a single string  
127         py::object tb_str(py::str("\n").join(tb_list));  
128         // Extract the string, check the extraction, and fallback in necessary  
129         py::extract<std::string> returned(tb_str);  
130         if(returned.check())  
131             ret += ": " + returned();  
132         else  
133             ret += std::string(": Unparseable Python traceback");  
134     }  
135     return ret;  
136 }  
137