X-Git-Url: http://47.100.26.94:8080/?a=blobdiff_plain;f=include%2FPredictorWrapper.h;h=c41a8bca6907dca5763a210baa45bc401c4bafb6;hb=db369d962b595544373b417ae9a76e7268eb12fb;hp=fab949730a247b40ffcc8a0af34925168f3db7be;hpb=48adce31a0ffdb3757ee1be8a63ce7e769e87deb;p=trackerpp.git diff --git a/include/PredictorWrapper.h b/include/PredictorWrapper.h index fab9497..c41a8bc 100644 --- a/include/PredictorWrapper.h +++ b/include/PredictorWrapper.h @@ -3,6 +3,7 @@ #include "SharedPtr.h" #include +#include namespace suanzi { @@ -13,13 +14,13 @@ namespace suanzi { class PredictorWrapper { public: - static PredictorWrapperPtr create(const std::string& fname); + static PredictorWrapperPtr create(const std::string& python_dir, const std::string& model_dir); // model.pkl file ~PredictorWrapper(){} - void dump() { this->dump_func(); } - void predict() { this->predict_func();} + void dump(); + double predict(int index, const std::vector& f); private: - PredictorWrapper(const std::string& fname); + PredictorWrapper(const std::string& py_dir, const std::string& fname); static PredictorWrapperWPtr instance; PY_FUN dump_func;