Add test for Predictor
[trackerpp.git] / include / PredictorWrapper.h
index fab9497..c41a8bc 100644 (file)
@@ -3,6 +3,7 @@
 
 #include "SharedPtr.h"
 #include <boost/python.hpp>
+#include <vector>
 
 namespace suanzi {
 
@@ -13,13 +14,13 @@ namespace suanzi {
     class PredictorWrapper
     {
     public:
-        static PredictorWrapperPtr create(const std::string& fname);
+        static PredictorWrapperPtr create(const std::string& python_dir, const std::string& model_dir); // model.pkl file
         ~PredictorWrapper(){}
-        void dump() { this->dump_func(); }
-        void predict() { this->predict_func();}
+        void dump();
+        double predict(int index, const std::vector<double>& f);
 
     private:
-        PredictorWrapper(const std::string& fname);
+        PredictorWrapper(const std::string& py_dir, const std::string& fname);
         static PredictorWrapperWPtr instance;
 
         PY_FUN dump_func;