Add boost python, and predictor wrapper
[trackerpp.git] / src / PredictorWrapper.cpp
1 #include "PredictorWrapper.h"
2 #include <string>
3 #include <iostream>
4
5 namespace py = boost::python;
6
7 using namespace std;
8 using namespace suanzi;
9
10
11 PredictorWrapperWPtr PredictorWrapper::instance;
12
13 const static std::string PREDICTOR_PY_DIR = "./python";
14
15 static std::string parse_python_exception();
16
17
18 PredictorWrapperPtr PredictorWrapper::create(const std::string& fname)
19 {
20     //if (instance == nullptr){
21     if (instance.lock()){
22         //instance = new PredictorWrapper(fname);
23         return PredictorWrapperPtr();
24     }
25     PredictorWrapperPtr ins (new PredictorWrapper(fname));
26     instance = ins;
27     return ins;
28 }
29
30 PredictorWrapper::PredictorWrapper(const std::string& fname)
31 {
32     Py_Initialize();
33     try{
34         py::object main_module = py::import("__main__");
35         py::object main_namespace = main_module.attr("__dict__");
36         py::exec("import sys", main_namespace);
37         std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')";
38         py::exec(cmd.c_str(), main_namespace);
39         //py::exec("sys.path.insert(0, '/home/debian/project/tracker/python')", main_namespace);
40         //py::exec("sys.path.insert(0, './python')", main_namespace);
41         py::object predictor_mod = py::import("predictor");
42         py::object predictor_init = predictor_mod.attr("init");
43         dump_func = predictor_mod.attr("dump");
44         predict_func = predictor_mod.attr("predict");
45
46         predictor_init(fname.c_str());
47         //predictor_dump();
48         //py::exec("import predictor", main_namespace);
49     } catch (boost::python::error_already_set const &){
50         std::string perror_str = parse_python_exception();
51         std::cout << "Error in Python: " << perror_str << std::endl;
52     }
53 }
54
55 static std::string parse_python_exception(){  
56     PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;  
57     // Fetch the exception info from the Python C API  
58     PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);  
59   
60     // Fallback error  
61     std::string ret("Unfetchable Python error");  
62     // If the fetch got a type pointer, parse the type into the exception string  
63     if(type_ptr != NULL){  
64         py::handle<> h_type(type_ptr);  
65         py::str type_pstr(h_type);  
66         // Extract the string from the boost::python object  
67         py::extract<std::string> e_type_pstr(type_pstr);  
68         // If a valid string extraction is available, use it   
69         //  otherwise use fallback  
70         if(e_type_pstr.check())  
71             ret = e_type_pstr();  
72         else  
73             ret = "Unknown exception type";  
74     }  
75     // Do the same for the exception value (the stringification of the exception)  
76     if(value_ptr != NULL){  
77         py::handle<> h_val(value_ptr);  
78         py::str a(h_val);  
79         py::extract<std::string> returned(a);  
80         if(returned.check())  
81             ret +=  ": " + returned();  
82         else  
83             ret += std::string(": Unparseable Python error: ");  
84     }  
85     // Parse lines from the traceback using the Python traceback module  
86     if(traceback_ptr != NULL){  
87         py::handle<> h_tb(traceback_ptr);  
88         // Load the traceback module and the format_tb function  
89         py::object tb(py::import("traceback"));  
90         py::object fmt_tb(tb.attr("format_tb"));  
91         // Call format_tb to get a list of traceback strings  
92         py::object tb_list(fmt_tb(h_tb));  
93         // Join the traceback strings into a single string  
94         py::object tb_str(py::str("\n").join(tb_list));  
95         // Extract the string, check the extraction, and fallback in necessary  
96         py::extract<std::string> returned(tb_str);  
97         if(returned.check())  
98             ret += ": " + returned();  
99         else  
100             ret += std::string(": Unparseable Python traceback");  
101     }  
102     return ret;  
103 }  
104