X-Git-Url: http://47.100.26.94:8080/?a=blobdiff_plain;f=src%2FPredictorWrapper.cpp;h=77d99affa6f087f8be96bf8700aa1314a6b2d0e4;hb=HEAD;hp=20fefc3e6326ec2737d539fd2d3dc0528ea18c10;hpb=3aa517d206c44156fe86697aeadc5f75ea212329;p=trackerpp.git diff --git a/src/PredictorWrapper.cpp b/src/PredictorWrapper.cpp index 20fefc3..77d99af 100644 --- a/src/PredictorWrapper.cpp +++ b/src/PredictorWrapper.cpp @@ -3,6 +3,7 @@ #include "Logger.h" #include #include +#include "PyWrapper.h" namespace py = boost::python; @@ -10,128 +11,60 @@ using namespace std; using namespace suanzi; -PredictorWrapperWPtr PredictorWrapper::instance; - -const static std::string PREDICTOR_PY_DIR = "./python"; const static std::string TAG = "PredictorWrapper"; -static std::string parse_python_exception(); - template -boost::python::list toPythonList(std::vector vector) { +boost::python::list toPythonList(const std::vector& v) { typename std::vector::iterator iter; boost::python::list list; - for (iter = vector.begin(); iter != vector.end(); ++iter) { - list.append(*iter); + for(const auto& vv : v){ + list.append(vv); } return list; } -PredictorWrapperPtr PredictorWrapper::create(const std::string& fname) -{ - if (instance.lock()){ - return PredictorWrapperPtr(); - } - PredictorWrapperPtr ins (new PredictorWrapper(fname)); - instance = ins; - return ins; -} +#define CALL_WITH_EXCEPTION(expr, ret_val) \ + try{ \ + py::object py_ret = expr; \ + ret_val = py::extract(py_ret); \ + }catch(boost::python::error_already_set const &){ \ + LOG_ERROR(TAG, PyWrapper::parse_python_exception()); \ + throw runtime_error("Python error"); \ + } \ -void PredictorWrapper::dump() + +PredictorWrapper::PredictorWrapper(const string& module, const string& pydir) { - LOG_DEBUG(TAG, "dump"); - std::string ss = ""; + PyWrapperPtr interpreter = PyWrapper::getInstance(pydir); try{ - py::object ret = this->dump_func(); - ss = py::extract(ret); + m_module = interpreter->import(module); } catch (boost::python::error_already_set const &){ - std::string perror_str = parse_python_exception(); - LOG_ERROR(TAG, "Error in Python: " + perror_str) + LOG_ERROR(TAG, PyWrapper::parse_python_exception()); } - LOG_DEBUG(TAG, ss); } -double PredictorWrapper::predict(int index, std::vector ff) +bool PredictorWrapper::load(const string& fname) { - LOG_DEBUG(TAG, "predict"); - try{ - this->predict_func(index, toPythonList(ff)); - } catch (boost::python::error_already_set const &){ - std::string perror_str = parse_python_exception(); - LOG_ERROR(TAG, "Error in Python: " + perror_str) - } - return 0.1; + LOG_DEBUG(TAG, "load " + fname); + py::object func = m_module.attr("init"); + bool ret = true; + CALL_WITH_EXCEPTION(func(fname.c_str()), ret); + return ret; } -PredictorWrapper::PredictorWrapper(const std::string& fname) +void PredictorWrapper::dump() { - Py_Initialize(); - try{ - py::object main_module = py::import("__main__"); - py::object main_namespace = main_module.attr("__dict__"); - py::exec("import sys", main_namespace); - std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')"; - py::exec(cmd.c_str(), main_namespace); - py::exec("import signal", main_namespace); - py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace); - py::object predictor_mod = py::import("predictor"); - py::object predictor_init = predictor_mod.attr("init"); - dump_func = predictor_mod.attr("dump"); - predict_func = predictor_mod.attr("predict"); - predictor_init(fname.c_str()); - } catch (boost::python::error_already_set const &){ - std::string perror_str = parse_python_exception(); - LOG_ERROR(TAG, "Error in Python: " + perror_str) - } + LOG_DEBUG(TAG, "dump"); + py::object func = m_module.attr("dump"); + string ret = ""; + CALL_WITH_EXCEPTION(func(), ret); + LOG_DEBUG(TAG, ret); } -static std::string parse_python_exception(){ - PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL; - // Fetch the exception info from the Python C API - PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr); - - // Fallback error - std::string ret("Unfetchable Python error"); - // If the fetch got a type pointer, parse the type into the exception string - if(type_ptr != NULL){ - py::handle<> h_type(type_ptr); - py::str type_pstr(h_type); - // Extract the string from the boost::python object - py::extract e_type_pstr(type_pstr); - // If a valid string extraction is available, use it - // otherwise use fallback - if(e_type_pstr.check()) - ret = e_type_pstr(); - else - ret = "Unknown exception type"; - } - // Do the same for the exception value (the stringification of the exception) - if(value_ptr != NULL){ - py::handle<> h_val(value_ptr); - py::str a(h_val); - py::extract returned(a); - if(returned.check()) - ret += ": " + returned(); - else - ret += std::string(": Unparseable Python error: "); - } - // Parse lines from the traceback using the Python traceback module - if(traceback_ptr != NULL){ - py::handle<> h_tb(traceback_ptr); - // Load the traceback module and the format_tb function - py::object tb(py::import("traceback")); - py::object fmt_tb(tb.attr("format_tb")); - // Call format_tb to get a list of traceback strings - py::object tb_list(fmt_tb(h_tb)); - // Join the traceback strings into a single string - py::object tb_str(py::str("\n").join(tb_list)); - // Extract the string, check the extraction, and fallback in necessary - py::extract returned(tb_str); - if(returned.check()) - ret += ": " + returned(); - else - ret += std::string(": Unparseable Python traceback"); - } - return ret; -} - +double PredictorWrapper::predict(int index, const vector& features) +{ + py::object func = m_module.attr("predict"); + double rr = 0; + CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr); + return rr; +}