X-Git-Url: http://47.100.26.94:8080/?a=blobdiff_plain;f=include%2FPredictorWrapper.h;h=18ea862e172bee6e828bbc1d530b04d23d22a504;hb=5c5910012c8e6e8fd4c0627c81acbacb1d06ba56;hp=ac54f733be99936cf96c1e0bb8e5ba93bb6e97c6;hpb=3ff9a5ad691b8dca9d91f8e9786a8d08d31b70fa;p=trackerpp.git diff --git a/include/PredictorWrapper.h b/include/PredictorWrapper.h index ac54f73..18ea862 100644 --- a/include/PredictorWrapper.h +++ b/include/PredictorWrapper.h @@ -3,6 +3,7 @@ #include "SharedPtr.h" #include +#include namespace suanzi { @@ -13,18 +14,14 @@ namespace suanzi { class PredictorWrapper { public: - static PredictorWrapperPtr create(const std::string& fname); + PredictorWrapper(const std::string& module="predictor", const std::string& pydir = "./python"); ~PredictorWrapper(){} void dump(); - double predict(); + double predict(int index, const std::vector& f); + bool load(const std::string& fname); // load pkl file private: - PredictorWrapper(const std::string& fname); - static PredictorWrapperWPtr instance; - - PY_FUN dump_func; - PY_FUN predict_func; - + boost::python::object m_module; }; }