separate pytwrpper from predictorWrappery
[trackerpp.git] / src / PredictorWrapper.cpp
1 #include "PredictorWrapper.h"
2 #include <string>
3 #include "Logger.h"
4 #include <iostream>
5 #include <thread>
6 #include "PyWrapper.h"
7
8 namespace py = boost::python;
9
10 using namespace std;
11 using namespace suanzi;
12
13
14 const static std::string TAG = "PredictorWrapper";
15
16 template <class T>
17 boost::python::list toPythonList(const std::vector<T>& v) {
18     typename std::vector<T>::iterator iter;
19     boost::python::list list;
20     for(const auto& vv : v){
21         list.append(vv);
22     }
23     return list;
24 }
25
26 #define CALL_WITH_EXCEPTION(expr, ret_val)                      \
27     try{                                                        \
28         py::object py_ret = expr;                               \
29         ret_val = py::extract<decltype(ret_val)>(py_ret);       \
30     }catch(boost::python::error_already_set const &){           \
31         LOG_ERROR(TAG, PyWrapper::parse_python_exception());    \
32     }                                                           \
33
34
35 PredictorWrapper::PredictorWrapper(const string& module, const string& pydir)
36 {
37     PyWrapperPtr interpreter = PyWrapper::getInstance(pydir);
38     try{
39         m_module = interpreter->import(module);
40     } catch (boost::python::error_already_set const &){
41         LOG_ERROR(TAG, PyWrapper::parse_python_exception());
42     }
43 }
44
45 bool PredictorWrapper::load(const string& fname)
46 {
47     LOG_DEBUG(TAG, "load " + fname);
48     py::object func = m_module.attr("init");
49     string ret;
50     CALL_WITH_EXCEPTION(func(fname.c_str()), ret);
51     return true;
52 }
53
54 void PredictorWrapper::dump()
55 {
56     LOG_DEBUG(TAG, "dump");
57     py::object func = m_module.attr("dump");
58     string ret = "";
59     CALL_WITH_EXCEPTION(func(), ret);
60     LOG_DEBUG(TAG, ret);
61 }
62
63 double PredictorWrapper::predict(int index, const vector<double>& features)
64 {
65     py::object func = m_module.attr("predict");
66     double rr = 0;
67     CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr);
68     return rr;
69 }