Add test for Predictor
[trackerpp.git] / include / PredictorWrapper.h
1 #ifndef _PREDICTOR_H_
2 #define _PREDICTOR_H_
3
4 #include "SharedPtr.h"
5 #include <boost/python.hpp>
6 #include <vector>
7
8 namespace suanzi {
9
10     TK_DECLARE_PTR(PredictorWrapper);
11
12     typedef boost::python::object PY_FUN;
13
14     class PredictorWrapper
15     {
16     public:
17         static PredictorWrapperPtr create(const std::string& python_dir, const std::string& model_dir); // model.pkl file
18         ~PredictorWrapper(){}
19         void dump();
20         double predict(int index, const std::vector<double>& f);
21
22     private:
23         PredictorWrapper(const std::string& py_dir, const std::string& fname);
24         static PredictorWrapperWPtr instance;
25
26         PY_FUN dump_func;
27         PY_FUN predict_func;
28
29     };
30
31 }
32
33 #endif // _PREDICTOR_H_