Fix nan issue in features
[trackerpp.git] / src / PredictorWrapper.cpp
1 #include "PredictorWrapper.h"
2 #include <string>
3 #include "Logger.h"
4 #include <iostream>
5 #include <thread>
6 #include "PyWrapper.h"
7
8 namespace py = boost::python;
9
10 using namespace std;
11 using namespace suanzi;
12
13
14 const static std::string TAG = "PredictorWrapper";
15
16 template <class T>
17 boost::python::list toPythonList(const std::vector<T>& v) {
18     typename std::vector<T>::iterator iter;
19     boost::python::list list;
20     for(const auto& vv : v){
21         list.append(vv);
22     }
23     return list;
24 }
25
26 #define CALL_WITH_EXCEPTION(expr, ret_val)                      \
27     try{                                                        \
28         py::object py_ret = expr;                               \
29         ret_val = py::extract<decltype(ret_val)>(py_ret);       \
30     }catch(boost::python::error_already_set const &){           \
31         LOG_ERROR(TAG, PyWrapper::parse_python_exception());    \
32         throw runtime_error("Python error");                    \
33     }                                                           \
34
35
36 PredictorWrapper::PredictorWrapper(const string& module, const string& pydir)
37 {
38     PyWrapperPtr interpreter = PyWrapper::getInstance(pydir);
39     try{
40         m_module = interpreter->import(module);
41     } catch (boost::python::error_already_set const &){
42         LOG_ERROR(TAG, PyWrapper::parse_python_exception());
43     }
44 }
45
46 bool PredictorWrapper::load(const string& fname)
47 {
48     LOG_DEBUG(TAG, "load " + fname);
49     py::object func = m_module.attr("init");
50     bool ret = true;
51     CALL_WITH_EXCEPTION(func(fname.c_str()), ret);
52     return ret;
53 }
54
55 void PredictorWrapper::dump()
56 {
57     LOG_DEBUG(TAG, "dump");
58     py::object func = m_module.attr("dump");
59     string ret = "";
60     CALL_WITH_EXCEPTION(func(), ret);
61     LOG_DEBUG(TAG, ret);
62 }
63
64 double PredictorWrapper::predict(int index, const vector<double>& features)
65 {
66     py::object func = m_module.attr("predict");
67     double rr = 0;
68     CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr);
69     return rr;
70 }