X-Git-Url: http://47.100.26.94:8080/?a=blobdiff_plain;f=src%2FPredictorWrapper.cpp;h=77d99affa6f087f8be96bf8700aa1314a6b2d0e4;hb=HEAD;hp=8c5965f7e5845b16a266713954fb846b985e8521;hpb=48adce31a0ffdb3757ee1be8a63ce7e769e87deb;p=trackerpp.git diff --git a/src/PredictorWrapper.cpp b/src/PredictorWrapper.cpp index 8c5965f..77d99af 100644 --- a/src/PredictorWrapper.cpp +++ b/src/PredictorWrapper.cpp @@ -1,6 +1,9 @@ #include "PredictorWrapper.h" #include +#include "Logger.h" #include +#include +#include "PyWrapper.h" namespace py = boost::python; @@ -8,97 +11,60 @@ using namespace std; using namespace suanzi; -PredictorWrapperWPtr PredictorWrapper::instance; +const static std::string TAG = "PredictorWrapper"; -const static std::string PREDICTOR_PY_DIR = "./python"; +template +boost::python::list toPythonList(const std::vector& v) { + typename std::vector::iterator iter; + boost::python::list list; + for(const auto& vv : v){ + list.append(vv); + } + return list; +} -static std::string parse_python_exception(); +#define CALL_WITH_EXCEPTION(expr, ret_val) \ + try{ \ + py::object py_ret = expr; \ + ret_val = py::extract(py_ret); \ + }catch(boost::python::error_already_set const &){ \ + LOG_ERROR(TAG, PyWrapper::parse_python_exception()); \ + throw runtime_error("Python error"); \ + } \ -PredictorWrapperPtr PredictorWrapper::create(const std::string& fname) +PredictorWrapper::PredictorWrapper(const string& module, const string& pydir) { - //if (instance == nullptr){ - if (instance.lock()){ - //instance = new PredictorWrapper(fname); - return PredictorWrapperPtr(); + PyWrapperPtr interpreter = PyWrapper::getInstance(pydir); + try{ + m_module = interpreter->import(module); + } catch (boost::python::error_already_set const &){ + LOG_ERROR(TAG, PyWrapper::parse_python_exception()); } - PredictorWrapperPtr ins (new PredictorWrapper(fname)); - instance = ins; - return ins; } -PredictorWrapper::PredictorWrapper(const std::string& fname) +bool PredictorWrapper::load(const string& fname) { - Py_Initialize(); - try{ - py::object main_module = py::import("__main__"); - py::object main_namespace = main_module.attr("__dict__"); - py::exec("import sys", main_namespace); - std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')"; - py::exec(cmd.c_str(), main_namespace); - //py::exec("sys.path.insert(0, '/home/debian/project/tracker/python')", main_namespace); - //py::exec("sys.path.insert(0, './python')", main_namespace); - py::object predictor_mod = py::import("predictor"); - py::object predictor_init = predictor_mod.attr("init"); - dump_func = predictor_mod.attr("dump"); - predict_func = predictor_mod.attr("predict"); - - predictor_init(fname.c_str()); - //predictor_dump(); - //py::exec("import predictor", main_namespace); - } catch (boost::python::error_already_set const &){ - std::string perror_str = parse_python_exception(); - std::cout << "Error in Python: " << perror_str << std::endl; - } + LOG_DEBUG(TAG, "load " + fname); + py::object func = m_module.attr("init"); + bool ret = true; + CALL_WITH_EXCEPTION(func(fname.c_str()), ret); + return ret; } -static std::string parse_python_exception(){ - PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL; - // Fetch the exception info from the Python C API - PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr); - - // Fallback error - std::string ret("Unfetchable Python error"); - // If the fetch got a type pointer, parse the type into the exception string - if(type_ptr != NULL){ - py::handle<> h_type(type_ptr); - py::str type_pstr(h_type); - // Extract the string from the boost::python object - py::extract e_type_pstr(type_pstr); - // If a valid string extraction is available, use it - // otherwise use fallback - if(e_type_pstr.check()) - ret = e_type_pstr(); - else - ret = "Unknown exception type"; - } - // Do the same for the exception value (the stringification of the exception) - if(value_ptr != NULL){ - py::handle<> h_val(value_ptr); - py::str a(h_val); - py::extract returned(a); - if(returned.check()) - ret += ": " + returned(); - else - ret += std::string(": Unparseable Python error: "); - } - // Parse lines from the traceback using the Python traceback module - if(traceback_ptr != NULL){ - py::handle<> h_tb(traceback_ptr); - // Load the traceback module and the format_tb function - py::object tb(py::import("traceback")); - py::object fmt_tb(tb.attr("format_tb")); - // Call format_tb to get a list of traceback strings - py::object tb_list(fmt_tb(h_tb)); - // Join the traceback strings into a single string - py::object tb_str(py::str("\n").join(tb_list)); - // Extract the string, check the extraction, and fallback in necessary - py::extract returned(tb_str); - if(returned.check()) - ret += ": " + returned(); - else - ret += std::string(": Unparseable Python traceback"); - } - return ret; -} +void PredictorWrapper::dump() +{ + LOG_DEBUG(TAG, "dump"); + py::object func = m_module.attr("dump"); + string ret = ""; + CALL_WITH_EXCEPTION(func(), ret); + LOG_DEBUG(TAG, ret); +} +double PredictorWrapper::predict(int index, const vector& features) +{ + py::object func = m_module.attr("predict"); + double rr = 0; + CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr); + return rr; +}