1 #include "PredictorWrapper.h"
// Short alias so boost::python names can be written as `py::...` below.
7 namespace py = boost::python;
10 using namespace suanzi;
// Out-of-class definition of the static member PredictorWrapper::instance.
// NOTE(review): the type name suggests a weak pointer to the wrapper — confirm in PredictorWrapper.h.
13 PredictorWrapperWPtr PredictorWrapper::instance;
// Tag passed as the first argument to every LOG_DEBUG / LOG_ERROR call in this file.
15 const static std::string TAG = "PredictorWrapper";
// Forward declaration; the definition appears further down in this file.
17 static std::string parse_python_exception();
// Takes a std::vector<T> by const reference and returns a boost::python::list.
// NOTE(review): the template header, the loop body and the return statement are
// not visible in this excerpt; presumably each element is appended to `list`
// and `list` is returned — confirm against the full file.
20 boost::python::list toPythonList(const std::vector<T>& v) {
// NOTE(review): `iter` is not referenced by any visible line; it looks unused.
21 typename std::vector<T>::iterator iter;
22 boost::python::list list;
// Visit every element of `v` by const reference.
23 for(const auto& vv : v){
// Factory for PredictorWrapper.
//   python_dir: forwarded unchanged to the PredictorWrapper constructor.
//   model_path: forwarded unchanged to the PredictorWrapper constructor.
// Returns a PredictorWrapperPtr; one visible path returns a default-constructed
// (empty) pointer.
29 PredictorWrapperPtr PredictorWrapper::create(const std::string& python_dir, const std::string& model_path)
// Early exit with an empty pointer. NOTE(review): the condition guarding this
// return is not visible in this excerpt.
32 return PredictorWrapperPtr();
// Construct a new wrapper and hold it in `ins`.
34 PredictorWrapperPtr ins (new PredictorWrapper(python_dir, model_path));
// Calls the Python callable stored in dump_func (bound to the `dump` attribute
// of the `predictor` module in the constructor) and extracts its result as a
// std::string. A Python error raised by the call is logged via LOG_ERROR.
39 void PredictorWrapper::dump()
41 LOG_DEBUG(TAG, "dump");
// Progress markers written directly to cout, in addition to the LOG_DEBUG above.
42 cout << "dump" << endl;
45 cout << "==== before" << endl;
// Invoke the Python callable with no arguments.
46 py::object ret = this->dump_func();
// Convert the Python return value to std::string.
// NOTE(review): `ss` is declared on a line not visible in this excerpt.
47 ss = py::extract<std::string>(ret);
// Boost.Python reports a raised Python exception by throwing error_already_set.
48 } catch (boost::python::error_already_set const &){
// Turn the pending Python error into text and log it.
49 std::string perror_str = parse_python_exception();
50 LOG_ERROR(TAG, "Error in Python: " + perror_str)
// Calls the Python callable stored in predict_func (bound to the `predict`
// attribute of the `predictor` module in the constructor) and returns its
// result converted to double.
//   index: passed through to Python as the first argument.
//   ff:    converted with toPythonList and passed as the second argument.
55 double PredictorWrapper::predict(int index, const std::vector<double>& ff)
57 LOG_DEBUG(TAG, "predict");
// NOTE(review): `ret` is declared on a line not visible in this excerpt.
60 ret = this->predict_func(index, toPythonList(ff));
// Boost.Python reports a raised Python exception by throwing error_already_set.
61 } catch (boost::python::error_already_set const &){
// Turn the pending Python error into text and log it.
62 std::string perror_str = parse_python_exception();
63 LOG_ERROR(TAG, "Error in Python: " + perror_str)
// NOTE(review): this extraction follows the catch block in the visible lines;
// if the Python call failed, `ret` may not hold a number — confirm how the
// error path exits before reaching this line.
65 double rr = py::extract<double>(ret);
66 LOG_DEBUG(TAG, "return: " + std::to_string(rr));
// Constructor: prepares the Python side and binds the `predictor` module's
// entry points.
//   py_dir:     directory inserted at the front of Python's sys.path.
//   model_path: passed to the Python function `predictor.init`.
// A Python error raised during any of these steps is logged via LOG_ERROR.
70 PredictorWrapper::PredictorWrapper(const std::string& py_dir, const std::string& model_path)
// Fetch __main__ and its __dict__; the dict is the namespace for the exec calls below.
74 py::object main_module = py::import("__main__");
75 py::object main_namespace = main_module.attr("__dict__");
76 py::exec("import sys", main_namespace);
// Prepend py_dir to sys.path so the import of `predictor` below can find it.
// NOTE(review): py_dir is spliced into Python source unescaped; a path
// containing a single quote or a backslash would alter or break this statement.
77 std::string cmd = "sys.path.insert(0, '" + py_dir + "')";
78 py::exec(cmd.c_str(), main_namespace);
// Set the SIGINT handler to signal.SIG_DFL on the Python side.
79 py::exec("import signal", main_namespace);
80 py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace);
// Import the `predictor` module and look up its init / dump / predict attributes.
81 py::object predictor_mod = py::import("predictor");
82 py::object predictor_init = predictor_mod.attr("init");
// dump_func and predict_func are kept for later use by dump() and predict().
83 dump_func = predictor_mod.attr("dump");
84 predict_func = predictor_mod.attr("predict");
// Call predictor.init with the model path.
85 predictor_init(model_path.c_str());
// Boost.Python reports a raised Python exception by throwing error_already_set.
86 } catch (boost::python::error_already_set const &){
// Turn the pending Python error into text and log it.
87 std::string perror_str = parse_python_exception();
88 LOG_ERROR(TAG, "Error in Python: " + perror_str)
92 static std::string parse_python_exception(){
93 PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;
94 // Fetch the exception info from the Python C API
95 PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);
98 std::string ret("Unfetchable Python error");
99 // If the fetch got a type pointer, parse the type into the exception string
100 if(type_ptr != NULL){
101 py::handle<> h_type(type_ptr);
102 py::str type_pstr(h_type);
103 // Extract the string from the boost::python object
104 py::extract<std::string> e_type_pstr(type_pstr);
105 // If a valid string extraction is available, use it
106 // otherwise use fallback
107 if(e_type_pstr.check())
110 ret = "Unknown exception type";
112 // Do the same for the exception value (the stringification of the exception)
113 if(value_ptr != NULL){
114 py::handle<> h_val(value_ptr);
116 py::extract<std::string> returned(a);
118 ret += ": " + returned();
120 ret += std::string(": Unparseable Python error: ");
122 // Parse lines from the traceback using the Python traceback module
123 if(traceback_ptr != NULL){
124 py::handle<> h_tb(traceback_ptr);
125 // Load the traceback module and the format_tb function
126 py::object tb(py::import("traceback"));
127 py::object fmt_tb(tb.attr("format_tb"));
128 // Call format_tb to get a list of traceback strings
129 py::object tb_list(fmt_tb(h_tb));
130 // Join the traceback strings into a single string
131 py::object tb_str(py::str("\n").join(tb_list));
132 // Extract the string, check the extraction, and fall back if necessary
133 py::extract<std::string> returned(tb_str);
135 ret += ": " + returned();
137 ret += std::string(": Unparseable Python traceback");