From db369d962b595544373b417ae9a76e7268eb12fb Mon Sep 17 00:00:00 2001 From: Peng Li Date: Sun, 22 Jul 2018 23:17:10 +0800 Subject: [PATCH] Add test for Predictor --- SConstruct | 2 +- include/PredictorWrapper.h | 4 ++-- include/Tracker.h | 2 ++ src/MultiTracker.cpp | 29 ++++++++++------------------- src/PredictorWrapper.cpp | 13 +++++++------ src/Tracker.cpp | 1 + test/SConscript | 7 +++++-- test/TestMain.cpp | 2 ++ test/TestPredictor.cpp | 12 ++++++++++++ 9 files changed, 42 insertions(+), 30 deletions(-) create mode 100644 test/TestPredictor.cpp diff --git a/SConstruct b/SConstruct index 75897cf..414727d 100644 --- a/SConstruct +++ b/SConstruct @@ -17,7 +17,7 @@ env.ParseConfig("pkg-config --libs opencv log4cpp") env.ParseConfig("python-config --cflags --libs") env.Append(LIBS = ['pthread', 'boost_python']) env['CCFLAGS'].remove('-Wstrict-prototypes') # invalid in C++ -env['CCFLAGS'].remove('-g') # invalid in C++ +env['CCFLAGS'].remove('-g') env.Append(LIBPATH=['#.']) diff --git a/include/PredictorWrapper.h b/include/PredictorWrapper.h index 41f818f..c41a8bc 100644 --- a/include/PredictorWrapper.h +++ b/include/PredictorWrapper.h @@ -14,13 +14,13 @@ namespace suanzi { class PredictorWrapper { public: - static PredictorWrapperPtr create(const std::string& fname); + static PredictorWrapperPtr create(const std::string& python_dir, const std::string& model_dir); // model.pkl file ~PredictorWrapper(){} void dump(); double predict(int index, const std::vector& f); private: - PredictorWrapper(const std::string& fname); + PredictorWrapper(const std::string& py_dir, const std::string& fname); static PredictorWrapperWPtr instance; PY_FUN dump_func; diff --git a/include/Tracker.h b/include/Tracker.h index 9eb3139..aaa3d1b 100644 --- a/include/Tracker.h +++ b/include/Tracker.h @@ -7,6 +7,7 @@ #include "Metrics.h" #include "SharedPtr.h" #include "MultiTracker.h" +#include "Detector.h" namespace suanzi { @@ -29,6 +30,7 @@ namespace suanzi { void addPatch(PatchPtr p); TrackerStatus status; std::vector patches; + Detection detection; private: TrackerStatus preStatus; diff --git a/src/MultiTracker.cpp b/src/MultiTracker.cpp index f3a7cad..a7f5e59 100644 --- a/src/MultiTracker.cpp +++ b/src/MultiTracker.cpp @@ -12,18 +12,16 @@ using namespace std; static const std::string TAG = "MultiTracker"; static const cv::Size PREFERRED_SIZE = Size(64, 128); -#define MaxCost 100000 +static const double MaxCost = 100000; +static const int MaxPath = 5; MultiTracker::MultiTracker(EngineWPtr e) : engine(e) { LOG_DEBUG(TAG, "init - loading model.pkl"); - predictor = PredictorWrapper::create("./python/model.pkl"); + predictor = PredictorWrapper::create("./python", "./python/model.pkl"); predictor->dump(); this->descriptor = {Size(64, 128), Size(16, 16), Size(8, 8), Size(8, 8), 9}; - - std::vector ff (40, 1); - double prob = predictor->predict(4, ff); } MultiTracker::~MultiTracker() @@ -66,7 +64,6 @@ static std::vector similarity(const PatchPtr p1, const PatchPtr p2) return feature; } - double MultiTracker::distance(TrackerPtr tracker, const cv::Mat& image, const Detection& d) { PatchPtr patch = createPatch(image, d); @@ -81,27 +78,21 @@ double MultiTracker::distance(TrackerPtr tracker, const cv::Mat& image, const De return prob; } -static long cc = 0; +static float calc_iou_ratio(const Detection& d1, const Detection& d2) +{ + return calc_iou_ratio(getRectInDetection(d1), getRectInDetection(d2)); +} void MultiTracker::update(unsigned int total, const Detection* detections, const Mat& image) { - ////// - if ((cc % 50) == 0){ - if (EnginePtr e = engine.lock()){ - e->onStatusChanged(); - } - } - cc++; - - ////// int row = trackers.size(); int col = total; Eigen::MatrixXi cost_matrix = Eigen::MatrixXi::Zero(row, col); for (int i = 0; i < row; i++){ for (int j = 0; j < col; j++){ - //if (calc_iou_ratio(trackers[i], detections[j]) < -0.1) - // cost_matrix(i, j) = MaxCost; - //else + if (calc_iou_ratio(trackers[i]->detection, detections[j]) < -0.1) + cost_matrix(i, j) = MaxCost; + else cost_matrix(i, j) = distance(trackers[i], image, detections[j]); } } diff --git a/src/PredictorWrapper.cpp b/src/PredictorWrapper.cpp index 2dc7eb4..44e5fed 100644 --- a/src/PredictorWrapper.cpp +++ b/src/PredictorWrapper.cpp @@ -12,7 +12,6 @@ using namespace suanzi; PredictorWrapperWPtr PredictorWrapper::instance; -const static std::string PREDICTOR_PY_DIR = "./python"; const static std::string TAG = "PredictorWrapper"; static std::string parse_python_exception(); @@ -27,12 +26,12 @@ boost::python::list toPythonList(const std::vector& v) { return list; } -PredictorWrapperPtr PredictorWrapper::create(const std::string& fname) +PredictorWrapperPtr PredictorWrapper::create(const std::string& python_dir, const std::string& model_path) { if (instance.lock()){ return PredictorWrapperPtr(); } - PredictorWrapperPtr ins (new PredictorWrapper(fname)); + PredictorWrapperPtr ins (new PredictorWrapper(python_dir, model_path)); instance = ins; return ins; } @@ -40,8 +39,10 @@ PredictorWrapperPtr PredictorWrapper::create(const std::string& fname) void PredictorWrapper::dump() { LOG_DEBUG(TAG, "dump"); + cout << "dump" << endl; std::string ss = ""; try{ + cout << "==== before" << endl; py::object ret = this->dump_func(); ss = py::extract(ret); } catch (boost::python::error_already_set const &){ @@ -66,14 +67,14 @@ double PredictorWrapper::predict(int index, const std::vector& ff) return rr; } -PredictorWrapper::PredictorWrapper(const std::string& fname) +PredictorWrapper::PredictorWrapper(const std::string& py_dir, const std::string& model_path) { Py_Initialize(); try{ py::object main_module = py::import("__main__"); py::object main_namespace = main_module.attr("__dict__"); py::exec("import sys", main_namespace); - std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')"; + std::string cmd = "sys.path.insert(0, '" + py_dir + "')"; py::exec(cmd.c_str(), main_namespace); py::exec("import signal", main_namespace); py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace); @@ -81,7 +82,7 @@ PredictorWrapper::PredictorWrapper(const std::string& fname) py::object predictor_init = predictor_mod.attr("init"); dump_func = predictor_mod.attr("dump"); predict_func = predictor_mod.attr("predict"); - predictor_init(fname.c_str()); + predictor_init(model_path.c_str()); } catch (boost::python::error_already_set const &){ std::string perror_str = parse_python_exception(); LOG_ERROR(TAG, "Error in Python: " + perror_str) diff --git a/src/Tracker.cpp b/src/Tracker.cpp index 8e5caaf..8ab2414 100644 --- a/src/Tracker.cpp +++ b/src/Tracker.cpp @@ -2,6 +2,7 @@ using namespace suanzi; using namespace cv; +using namespace std; static const int MaxLost = 5; diff --git a/test/SConscript b/test/SConscript index 1ca3bdc..d381c57 100644 --- a/test/SConscript +++ b/test/SConscript @@ -12,8 +12,11 @@ if arch == 'x86_64': else: env1.Append(LIBPATH = ['#third_party/googletest/lib']) -env1['LIBS'] = ['tracker', 'gtest', 'pthread'] -env1.ParseConfig("pkg-config --libs opencv") +env1.Append(LIBS = ['gtest']) +#env1['LIBS'] = ['tracker', 'gtest', 'pthread'] +#env1.ParseConfig("pkg-config --libs opencv") +#env.ParseConfig("pkg-config --libs opencv log4cpp") +#env.ParseConfig("python-config --cflags --libs") obj = env1.Object(Glob("*.cpp")) diff --git a/test/TestMain.cpp b/test/TestMain.cpp index 3568448..374ee16 100644 --- a/test/TestMain.cpp +++ b/test/TestMain.cpp @@ -1,4 +1,5 @@ #include "gtest/gtest.h" +#include "Logger.h" int main(int argc, char** argv) { // Disables elapsed time by default. @@ -6,6 +7,7 @@ int main(int argc, char** argv) { // This allows the user to override the flag on the command line. ::testing::InitGoogleTest(&argc, argv); + initLogger("../config/log4cpp.properties"); return RUN_ALL_TESTS(); } diff --git a/test/TestPredictor.cpp b/test/TestPredictor.cpp new file mode 100644 index 0000000..c6087f0 --- /dev/null +++ b/test/TestPredictor.cpp @@ -0,0 +1,12 @@ +#include "gtest/gtest.h" +#include "PredictorWrapper.h" + +using namespace suanzi; + +TEST(Predictor, load) +{ + PredictorWrapperPtr predictor = PredictorWrapper::create("../python", "../python/model.pkl"); + predictor->dump(); + std::vector ff (40, 1); + double prob = predictor->predict(4, ff); +} -- 2.11.0