1 #include "PredictorWrapper.h"
7 namespace py = boost::python;
10 using namespace suanzi;
// Definition of the static member PredictorWrapper::instance.
// NOTE(review): the WPtr suffix suggests a weak (non-owning) pointer to the
// currently live wrapper, presumably consulted by create(); confirm against
// PredictorWrapper.h.
13 PredictorWrapperWPtr PredictorWrapper::instance;
// Directory the constructor prepends to Python's sys.path so the `predictor`
// module can be imported from it. It is a relative path, so Python resolves it
// against the process's current working directory.
15 const static std::string PREDICTOR_PY_DIR = "./python";
// Log tag passed to the LOG_DEBUG / LOG_ERROR calls in this file.
16 const static std::string TAG = "PredictorWrapper";
// Forward declaration. Fetches the pending Python error (via PyErr_Fetch, which
// also clears the interpreter's error indicator) and formats it as a string.
// It is defined further down; this declaration lets the catch handlers that
// precede the definition call it.
18 static std::string parse_python_exception();
// Factory for PredictorWrapper.
//
// On one path it returns an empty PredictorWrapperPtr. The condition guarding
// that early return is not shown in this listing. Otherwise it constructs a new
// wrapper, whose constructor imports the Python `predictor` module and calls
// its init(fname).
//
// NOTE(review): the guard and the remainder of this function are missing from
// this listing. Presumably the guard checks whether `instance` still refers to
// a live wrapper, making this a single-instance factory; confirm.
21 PredictorWrapperPtr PredictorWrapper::create(const std::string& fname)
24 return PredictorWrapperPtr();
// NOTE(review): a raw `new` is handed directly to PredictorWrapperPtr. If that
// type is a std::shared_ptr typedef (presumably), std::make_shared would be the
// idiomatic, exception-safe form.
26 PredictorWrapperPtr ins (new PredictorWrapper(fname));
// Calls the Python-side `predictor.dump` and converts its return value to a
// std::string stored in `ss`. The constructor caches that callable in
// dump_func.
//
// A Python exception raised here is caught, formatted with
// parse_python_exception() and logged. No rethrow is visible in this listing.
//
// NOTE(review): the declaration of `ss` and what is done with it after the call
// are not visible in this listing; confirm how the dumped string is used.
31 void PredictorWrapper::dump()
33 LOG_DEBUG(TAG, "dump");
36 py::object ret = this->dump_func();
// py::extract throws error_already_set if the result cannot be converted to
// std::string, so a non-str return value also lands in the catch below.
37 ss = py::extract<std::string>(ret);
38 } catch (boost::python::error_already_set const &){
39 std::string perror_str = parse_python_exception();
// NOTE(review): unlike the LOG_DEBUG call above, this has no trailing
// semicolon. Presumably the LOG_ERROR macro supplies one; confirm.
40 LOG_ERROR(TAG, "Error in Python: " + perror_str)
// Runs a prediction through the embedded Python `predictor` module and returns
// a double. A Python exception is caught, formatted with
// parse_python_exception() and logged.
//
// NOTE(review): the body of the try block is not visible in this listing.
// Presumably it calls predict_func, which the constructor binds to
// predictor.predict, and extracts a double. The value returned on the error
// path is not visible either; confirm what callers receive when Python fails.
45 double PredictorWrapper::predict()
47 LOG_DEBUG(TAG, "predict");
50 } catch (boost::python::error_already_set const &){
51 std::string perror_str = parse_python_exception();
// NOTE(review): no trailing semicolon, unlike the LOG_DEBUG call above.
// Presumably the LOG_ERROR macro supplies it; confirm.
52 LOG_ERROR(TAG, "Error in Python: " + perror_str)
// Loads the Python predictor backend and initializes it with `fname`.
//
// Steps:
//  1. In the __main__ module's namespace (via py::exec), import sys and
//     prepend PREDICTOR_PY_DIR to sys.path so `predictor` can be imported.
//  2. Also via py::exec, reset SIGINT to signal.SIG_DFL. Ctrl-C then uses the
//     OS default action (terminating the process) instead of Python's default
//     handler, which raises KeyboardInterrupt inside the embedded code.
//  3. Import `predictor`, cache its `dump` and `predict` callables in
//     dump_func / predict_func, and call predictor.init(fname).
// Any Python exception is formatted with parse_python_exception() and logged.
//
// NOTE(review): assumes the Python interpreter was initialized before this
// constructor runs; no Py_Initialize is visible here. Confirm.
// NOTE(review): no rethrow is visible in this listing. If there is none,
// construction completes even when dump_func / predict_func were never bound.
// Confirm.
57 PredictorWrapper::PredictorWrapper(const std::string& fname)
61 py::object main_module = py::import("__main__");
62 py::object main_namespace = main_module.attr("__dict__");
63 py::exec("import sys", main_namespace);
// PREDICTOR_PY_DIR is spliced into Python source inside single quotes, so it
// must not contain a single quote or backslash.
64 std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')";
65 py::exec(cmd.c_str(), main_namespace);
66 py::exec("import signal", main_namespace);
67 py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace);
68 py::object predictor_mod = py::import("predictor");
69 py::object predictor_init = predictor_mod.attr("init");
// Cache the module's entry points for later use by dump() and predict().
70 dump_func = predictor_mod.attr("dump");
71 predict_func = predictor_mod.attr("predict");
72 predictor_init(fname.c_str());
73 } catch (boost::python::error_already_set const &){
74 std::string perror_str = parse_python_exception();
// NOTE(review): no trailing semicolon. Presumably the LOG_ERROR macro supplies
// it; confirm.
75 LOG_ERROR(TAG, "Error in Python: " + perror_str)
79 static std::string parse_python_exception(){
80 PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;
81 // Fetch the exception info from the Python C API
82 PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);
85 std::string ret("Unfetchable Python error");
86 // If the fetch got a type pointer, parse the type into the exception string
88 py::handle<> h_type(type_ptr);
89 py::str type_pstr(h_type);
90 // Extract the string from the boost::python object
91 py::extract<std::string> e_type_pstr(type_pstr);
92 // If a valid string extraction is available, use it
93 // otherwise use fallback
94 if(e_type_pstr.check())
97 ret = "Unknown exception type";
99 // Do the same for the exception value (the stringification of the exception)
100 if(value_ptr != NULL){
101 py::handle<> h_val(value_ptr);
103 py::extract<std::string> returned(a);
105 ret += ": " + returned();
107 ret += std::string(": Unparseable Python error: ");
109 // Parse lines from the traceback using the Python traceback module
110 if(traceback_ptr != NULL){
111 py::handle<> h_tb(traceback_ptr);
112 // Load the traceback module and the format_tb function
113 py::object tb(py::import("traceback"));
114 py::object fmt_tb(tb.attr("format_tb"));
115 // Call format_tb to get a list of traceback strings
116 py::object tb_list(fmt_tb(h_tb));
117 // Join the traceback strings into a single string
118 py::object tb_str(py::str("\n").join(tb_list));
119 // Extract the string, check the extraction, and fallback in necessary
120 py::extract<std::string> returned(tb_str);
122 ret += ": " + returned();
124 ret += std::string(": Unparseable Python traceback");