1 #include "PredictorWrapper.h"
8 namespace py = boost::python;
11 using namespace suanzi;
// Tag passed as the first argument to LOG_DEBUG / LOG_ERROR throughout this file.
14 const static std::string TAG = "PredictorWrapper";
// Creates a boost::python::list and iterates over every element of `v`.
// The loop body and the return statement are not visible in this excerpt;
// presumably each element is appended to the list, which is then returned
// (TODO confirm).
17 boost::python::list toPythonList(const std::vector<T>& v) {
// NOTE(review): `iter` is not used in the visible lines, and it is a
// non-const iterator type while `v` is a const reference — looks like a
// leftover from an older iterator-based loop; confirm and remove.
18 typename std::vector<T>::iterator iter;
19 boost::python::list list;
// Elements are bound by const reference, so nothing is copied by the loop itself.
20 for(const auto& vv : v){
// CALL_WITH_EXCEPTION(expr, ret_val)
//   expr    - expression whose result is stored in a py::object.
//   ret_val - an already-declared variable; its declared type (via decltype)
//             selects the py::extract<> conversion, and it receives the result.
// If boost::python::error_already_set is caught, the text produced by
// PyWrapper::parse_python_exception() is logged with LOG_ERROR under TAG and
// a runtime_error("Python error") is thrown to the caller.
// NOTE(review): `runtime_error` is unqualified, so it relies on a using
// declaration/directive that is not visible in this excerpt — confirm.
// (Comments are kept above the #define on purpose: a `//` comment placed on a
// backslash-continued line would swallow the rest of the macro.)
26 #define CALL_WITH_EXCEPTION(expr, ret_val) \
28 py::object py_ret = expr; \
29 ret_val = py::extract<decltype(ret_val)>(py_ret); \
30 }catch(boost::python::error_already_set const &){ \
31 LOG_ERROR(TAG, PyWrapper::parse_python_exception()); \
32 throw runtime_error("Python error"); \
// Constructs the wrapper by obtaining the PyWrapper instance for `pydir` and
// importing the Python module named `module` into m_module.
//   module - name handed to PyWrapper::import().
//   pydir  - argument handed to PyWrapper::getInstance().
36 PredictorWrapper::PredictorWrapper(const string& module, const string& pydir)
38 PyWrapperPtr interpreter = PyWrapper::getInstance(pydir);
40 m_module = interpreter->import(module);
// A Python-side failure during import is logged under TAG.
// NOTE(review): what follows the log call (rethrow vs. swallow) is not
// visible in this excerpt; if the error is swallowed, m_module is left
// without the imported module — confirm.
41 } catch (boost::python::error_already_set const &){
42 LOG_ERROR(TAG, PyWrapper::parse_python_exception());
// Calls the Python module's `init` function with `fname` and extracts the
// result into `ret`.
//   fname - passed to Python as a C string (fname.c_str()).
// The declaration of `ret` and the return statement are not visible here;
// presumably `ret` is a bool that is returned (TODO confirm).
// A Python error raised by the call is logged and rethrown as
// runtime_error("Python error") by CALL_WITH_EXCEPTION.
46 bool PredictorWrapper::load(const string& fname)
48 LOG_DEBUG(TAG, "load " + fname);
// Note: the Python-side entry point is named "init", not "load".
// NOTE(review): this attribute lookup sits outside CALL_WITH_EXCEPTION, so a
// missing attribute does not appear to be covered by the macro's handler —
// confirm.
49 py::object func = m_module.attr("init");
51 CALL_WITH_EXCEPTION(func(fname.c_str()), ret);
// Calls the Python module's `dump` function with no arguments.
// The call's result is extracted into `ret`, whose declaration is not visible
// in this excerpt; since this function returns void, the value is presumably
// discarded (TODO confirm).
// A Python error raised by the call is logged and rethrown as
// runtime_error("Python error") by CALL_WITH_EXCEPTION.
55 void PredictorWrapper::dump()
57 LOG_DEBUG(TAG, "dump");
// NOTE(review): this attribute lookup sits outside CALL_WITH_EXCEPTION, so a
// missing attribute does not appear to be covered by the macro's handler —
// confirm.
58 py::object func = m_module.attr("dump");
60 CALL_WITH_EXCEPTION(func(), ret);
// Calls the Python module's `predict` function with `index` and `features`
// and extracts the result into `rr`.
//   index    - forwarded to Python unchanged.
//   features - converted with toPythonList() before being passed.
// The declaration of `rr` and the return statement are not visible here;
// presumably `rr` is a double that is returned (TODO confirm).
// A Python error raised by the call is logged and rethrown as
// runtime_error("Python error") by CALL_WITH_EXCEPTION.
64 double PredictorWrapper::predict(int index, const vector<double>& features)
// NOTE(review): this attribute lookup sits outside CALL_WITH_EXCEPTION, so a
// missing attribute does not appear to be covered by the macro's handler —
// confirm.
66 py::object func = m_module.attr("predict");
68 CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr);