Fix issue boostpython not stopped by Ctrl+c
[trackerpp.git] / src / PredictorWrapper.cpp
1 #include "PredictorWrapper.h"
2 #include <string>
3 #include "Logger.h"
4 #include <iostream>
5 #include <thread>
6
7 namespace py = boost::python;
8
9 using namespace std;
10 using namespace suanzi;
11
12
13 PredictorWrapperWPtr PredictorWrapper::instance;
14
15 const static std::string PREDICTOR_PY_DIR = "./python";
16 const static std::string TAG = "PredictorWrapper";
17
18 static std::string parse_python_exception();
19
20
21 PredictorWrapperPtr PredictorWrapper::create(const std::string& fname)
22 {
23     if (instance.lock()){
24         return PredictorWrapperPtr();
25     }
26     PredictorWrapperPtr ins (new PredictorWrapper(fname));
27     instance = ins;
28     return ins;
29 }
30
31 void PredictorWrapper::dump()
32 {
33     LOG_DEBUG(TAG, "dump");
34     std::string ss = "";
35     try{
36         py::object ret = this->dump_func();
37         ss = py::extract<std::string>(ret);
38     } catch (boost::python::error_already_set const &){
39         std::string perror_str = parse_python_exception();
40         LOG_ERROR(TAG, "Error in Python: " + perror_str)
41     }
42     LOG_DEBUG(TAG, ss);
43 }
44
45 double PredictorWrapper::predict()
46 {
47     LOG_DEBUG(TAG, "predict");
48     try{
49         this->predict_func();
50     } catch (boost::python::error_already_set const &){
51         std::string perror_str = parse_python_exception();
52         LOG_ERROR(TAG, "Error in Python: " + perror_str)
53     }
54     return 0.1;
55 }
56
57 PredictorWrapper::PredictorWrapper(const std::string& fname)
58 {
59     Py_Initialize();
60     try{
61         py::object main_module = py::import("__main__");
62         py::object main_namespace = main_module.attr("__dict__");
63         py::exec("import sys", main_namespace);
64         std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')";
65         py::exec(cmd.c_str(), main_namespace);
66         py::exec("import signal", main_namespace);
67         py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace);
68         py::object predictor_mod = py::import("predictor");
69         py::object predictor_init = predictor_mod.attr("init");
70         dump_func = predictor_mod.attr("dump");
71         predict_func = predictor_mod.attr("predict");
72         predictor_init(fname.c_str());
73     } catch (boost::python::error_already_set const &){
74         std::string perror_str = parse_python_exception();
75         LOG_ERROR(TAG, "Error in Python: " + perror_str)
76     }
77 }
78
79 static std::string parse_python_exception(){  
80     PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;  
81     // Fetch the exception info from the Python C API  
82     PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);  
83   
84     // Fallback error  
85     std::string ret("Unfetchable Python error");  
86     // If the fetch got a type pointer, parse the type into the exception string  
87     if(type_ptr != NULL){  
88         py::handle<> h_type(type_ptr);  
89         py::str type_pstr(h_type);  
90         // Extract the string from the boost::python object  
91         py::extract<std::string> e_type_pstr(type_pstr);  
92         // If a valid string extraction is available, use it   
93         //  otherwise use fallback  
94         if(e_type_pstr.check())  
95             ret = e_type_pstr();  
96         else  
97             ret = "Unknown exception type";  
98     }  
99     // Do the same for the exception value (the stringification of the exception)  
100     if(value_ptr != NULL){  
101         py::handle<> h_val(value_ptr);  
102         py::str a(h_val);  
103         py::extract<std::string> returned(a);  
104         if(returned.check())  
105             ret +=  ": " + returned();  
106         else  
107             ret += std::string(": Unparseable Python error: ");  
108     }  
109     // Parse lines from the traceback using the Python traceback module  
110     if(traceback_ptr != NULL){  
111         py::handle<> h_tb(traceback_ptr);  
112         // Load the traceback module and the format_tb function  
113         py::object tb(py::import("traceback"));  
114         py::object fmt_tb(tb.attr("format_tb"));  
115         // Call format_tb to get a list of traceback strings  
116         py::object tb_list(fmt_tb(h_tb));  
117         // Join the traceback strings into a single string  
118         py::object tb_str(py::str("\n").join(tb_list));  
119         // Extract the string, check the extraction, and fallback in necessary  
120         py::extract<std::string> returned(tb_str);  
121         if(returned.check())  
122             ret += ": " + returned();  
123         else  
124             ret += std::string(": Unparseable Python traceback");  
125     }  
126     return ret;  
127 }  
128