Add test for Predictor
[trackerpp.git] / src / PredictorWrapper.cpp
1 #include "PredictorWrapper.h"
2 #include <string>
3 #include "Logger.h"
4 #include <iostream>
5 #include <thread>
6
7 namespace py = boost::python;
8
9 using namespace std;
10 using namespace suanzi;
11
12
13 PredictorWrapperWPtr PredictorWrapper::instance;
14
15 const static std::string TAG = "PredictorWrapper";
16
17 static std::string parse_python_exception();
18
19 template <class T>
20 boost::python::list toPythonList(const std::vector<T>& v) {
21     typename std::vector<T>::iterator iter;
22     boost::python::list list;
23     for(const auto& vv : v){
24         list.append(vv);
25     }
26     return list;
27 }
28
29 PredictorWrapperPtr PredictorWrapper::create(const std::string& python_dir, const std::string& model_path)
30 {
31     if (instance.lock()){
32         return PredictorWrapperPtr();
33     }
34     PredictorWrapperPtr ins (new PredictorWrapper(python_dir, model_path));
35     instance = ins;
36     return ins;
37 }
38
39 void PredictorWrapper::dump()
40 {
41     LOG_DEBUG(TAG, "dump");
42     cout << "dump" << endl;
43     std::string ss = "";
44     try{
45         cout << "==== before" << endl;
46         py::object ret = this->dump_func();
47         ss = py::extract<std::string>(ret);
48     } catch (boost::python::error_already_set const &){
49         std::string perror_str = parse_python_exception();
50         LOG_ERROR(TAG, "Error in Python: " + perror_str)
51     }
52     LOG_DEBUG(TAG, ss);
53 }
54
55 double PredictorWrapper::predict(int index, const std::vector<double>& ff)
56 {
57     LOG_DEBUG(TAG, "predict");
58     py::object ret;
59     try{
60         ret = this->predict_func(index, toPythonList(ff));
61     } catch (boost::python::error_already_set const &){
62         std::string perror_str = parse_python_exception();
63         LOG_ERROR(TAG, "Error in Python: " + perror_str)
64     }
65     double rr = py::extract<double>(ret);
66     LOG_DEBUG(TAG, "return: " + std::to_string(rr));
67     return rr; 
68 }
69
70 PredictorWrapper::PredictorWrapper(const std::string& py_dir, const std::string& model_path)
71 {
72     Py_Initialize();
73     try{
74         py::object main_module = py::import("__main__");
75         py::object main_namespace = main_module.attr("__dict__");
76         py::exec("import sys", main_namespace);
77         std::string cmd = "sys.path.insert(0, '" + py_dir + "')";
78         py::exec(cmd.c_str(), main_namespace);
79         py::exec("import signal", main_namespace);
80         py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace);
81         py::object predictor_mod = py::import("predictor");
82         py::object predictor_init = predictor_mod.attr("init");
83         dump_func = predictor_mod.attr("dump");
84         predict_func = predictor_mod.attr("predict");
85         predictor_init(model_path.c_str());
86     } catch (boost::python::error_already_set const &){
87         std::string perror_str = parse_python_exception();
88         LOG_ERROR(TAG, "Error in Python: " + perror_str)
89     }
90 }
91
92 static std::string parse_python_exception(){  
93     PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;  
94     // Fetch the exception info from the Python C API  
95     PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);  
96   
97     // Fallback error  
98     std::string ret("Unfetchable Python error");  
99     // If the fetch got a type pointer, parse the type into the exception string  
100     if(type_ptr != NULL){  
101         py::handle<> h_type(type_ptr);  
102         py::str type_pstr(h_type);  
103         // Extract the string from the boost::python object  
104         py::extract<std::string> e_type_pstr(type_pstr);  
105         // If a valid string extraction is available, use it   
106         //  otherwise use fallback  
107         if(e_type_pstr.check())  
108             ret = e_type_pstr();  
109         else  
110             ret = "Unknown exception type";  
111     }  
112     // Do the same for the exception value (the stringification of the exception)  
113     if(value_ptr != NULL){  
114         py::handle<> h_val(value_ptr);  
115         py::str a(h_val);  
116         py::extract<std::string> returned(a);  
117         if(returned.check())  
118             ret +=  ": " + returned();  
119         else  
120             ret += std::string(": Unparseable Python error: ");  
121     }  
122     // Parse lines from the traceback using the Python traceback module  
123     if(traceback_ptr != NULL){  
124         py::handle<> h_tb(traceback_ptr);  
125         // Load the traceback module and the format_tb function  
126         py::object tb(py::import("traceback"));  
127         py::object fmt_tb(tb.attr("format_tb"));  
128         // Call format_tb to get a list of traceback strings  
129         py::object tb_list(fmt_tb(h_tb));  
130         // Join the traceback strings into a single string  
131         py::object tb_str(py::str("\n").join(tb_list));  
132         // Extract the string, check the extraction, and fallback in necessary  
133         py::extract<std::string> returned(tb_str);  
134         if(returned.check())  
135             ret += ": " + returned();  
136         else  
137             ret += std::string(": Unparseable Python traceback");  
138     }  
139     return ret;  
140 }  
141