From 48adce31a0ffdb3757ee1be8a63ce7e769e87deb Mon Sep 17 00:00:00 2001 From: Peng Li Date: Thu, 19 Jul 2018 02:28:57 +0800 Subject: [PATCH] Add boost python, and predictor wrapper --- README.md | 9 + SConstruct | 15 +- include/Detector.h | 3 +- include/Metrics.h | 18 +- include/MultiTracker.h | 2 +- include/PredictorWrapper.h | 32 ++++ include/hungarian.h | 9 + main.cpp | 5 +- python/model.pkl | 453 +++++++++++++++++++++++++++++++++++++++++++++ python/predictor.py | 35 ++++ src/Engine.cpp | 8 +- src/Metrics.cpp | 5 + src/MultiTracker.cpp | 70 ++++--- src/PredictorWrapper.cpp | 104 +++++++++++ src/hungarian.cpp | 12 ++ test/TestHungarian.cpp | 34 ++++ 16 files changed, 770 insertions(+), 44 deletions(-) create mode 100644 include/PredictorWrapper.h create mode 100644 python/model.pkl create mode 100644 python/predictor.py create mode 100644 src/PredictorWrapper.cpp diff --git a/README.md b/README.md index 13c713b..202ad06 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,15 @@ Tracker++ cpp version on Linux (arm) - log4cpp : logger utils - opencv - eigen : matrix library of C++ +- boost-python + - `apt-get install libpython-dev python-dev` + - build boost with python + - `pip install scipy numpy sklearn` + + +As boost has already + +`sudo ./b2 install link=static cxxflags=-fPIC --with-filesystem --with-test --with-log --with-program_options --with-python` `apt-get install liblog4cpp5-dev libopencv-dev libeigen3-dev` diff --git a/SConstruct b/SConstruct index fa54472..2a24ade 100644 --- a/SConstruct +++ b/SConstruct @@ -8,19 +8,24 @@ AddOption('--all', dest='all', action='store_true', help='Build all include test env = Environment(CXX="g++", CPPPATH=['#include'], - CCFLAGS=['-Wall', '-std=c++11', '-O2']) + CCFLAGS=['-Wall', '-std=c++11', '-O3']) + +env['ENV']['TERM'] = os.environ['TERM'] env.Append(LIBS = ['tracker']) env.ParseConfig("pkg-config --libs opencv log4cpp") -env.Append(LIBS = ['pthread']) +env.ParseConfig("python-config --cflags --libs") +env.Append(LIBS = ['pthread', 'boost_python']) env.Append(LIBPATH=['#.']) -obj = env.Object('main.cpp') - env.StaticLibrary('tracker', Glob('src/*.cpp')) +#tracker = env.StaticLibrary('tracker', Glob('src/*.cpp')) +#Depends(tracker, Glob('src/*.cpp')) -env.Program("main", list(obj)) +env.Program("main", 'main.cpp') +#main = env.Program("main", 'main.cpp') +#Depends(main, ["main.cpp", 'libtracker.a']) if GetOption('all'): SConscript('test/SConscript', exports='env') diff --git a/include/Detector.h b/include/Detector.h index edf875c..c1c5a9b 100644 --- a/include/Detector.h +++ b/include/Detector.h @@ -14,7 +14,8 @@ namespace suanzi { public: Detector(); virtual ~Detector(); - unsigned int detect(cv::Mat& frame, Detection* detections){return 1;} + // TODO + unsigned int detect(const cv::Mat& frame, Detection* detections){return 1;} }; struct Detection diff --git a/include/Metrics.h b/include/Metrics.h index c991d9f..3da8014 100644 --- a/include/Metrics.h +++ b/include/Metrics.h @@ -8,7 +8,8 @@ namespace suanzi { TK_DECLARE_PTR(Metrics); - TK_DECLARE_PTR(Patch); + //TK_DECLARE_PTR(Patch); + struct Patch; class Metrics { public: @@ -16,16 +17,23 @@ namespace suanzi { ~Metrics(){} const static long int MaxCost = 100000; const static int MaxPatch = 5; + void similarity(const Patch& p1, const Patch& p2); + private: cv::HOGDescriptor descriptor = {cv::Size(64, 128), cv::Size(16, 16), cv::Size(8, 8), cv::Size(8, 8), 9}; }; - class Patch + struct Patch { - public: - Patch(){}; - ~Patch(){}; + // bb_ltrb + + // + // image_crop + cv::Mat image_crop; + // + // features + }; } diff --git a/include/MultiTracker.h b/include/MultiTracker.h index 0fd1d3a..40744b8 100644 --- a/include/MultiTracker.h +++ b/include/MultiTracker.h @@ -20,7 +20,7 @@ namespace suanzi { private: MetricsPtr metrics; - std::set trackers; + std::vector trackers; int max_id = 0; void addTracker(TrackerPtr t); TrackerPtr createTracker(int id = 0); diff --git a/include/PredictorWrapper.h b/include/PredictorWrapper.h new file mode 100644 index 0000000..fab9497 --- /dev/null +++ b/include/PredictorWrapper.h @@ -0,0 +1,32 @@ +#ifndef _PREDICTOR_H_ +#define _PREDICTOR_H_ + +#include "SharedPtr.h" +#include + +namespace suanzi { + + TK_DECLARE_PTR(PredictorWrapper); + + typedef boost::python::object PY_FUN; + + class PredictorWrapper + { + public: + static PredictorWrapperPtr create(const std::string& fname); + ~PredictorWrapper(){} + void dump() { this->dump_func(); } + void predict() { this->predict_func();} + + private: + PredictorWrapper(const std::string& fname); + static PredictorWrapperWPtr instance; + + PY_FUN dump_func; + PY_FUN predict_func; + + }; + +} + +#endif // _PREDICTOR_H_ diff --git a/include/hungarian.h b/include/hungarian.h index 64b8596..df1f354 100644 --- a/include/hungarian.h +++ b/include/hungarian.h @@ -12,4 +12,13 @@ // @return the cost of the assignment int linear_sum_assignment(const Eigen::MatrixXi& cost_matrix, Eigen::VectorXi& row_ind, Eigen::VectorXi& col_ind); + +// Computes the consine distance between u and v +// https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cosine.html#scipy.spatial.distance.cosine +double distance_cosine(const Eigen::VectorXd& u, const Eigen::VectorXd& v); + +// Computes the Euclidean distance between u and v +// https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.euclidean.html#scipy.spatial.distance.euclidean +double distance_euclidean(const Eigen::VectorXd& u, const Eigen::VectorXd& v); + #endif // _HUNGARIAN_H_ diff --git a/main.cpp b/main.cpp index 1dbaef5..ddf49d6 100644 --- a/main.cpp +++ b/main.cpp @@ -19,18 +19,15 @@ class Callback : public EngineObserver }; }; - - int main(int argc, char* argv[]) { initLogger("./config/log4cpp.properties"); LOG_DEBUG(TAG, "=================================="); - EnginePtr e = Engine::create(); e->addObserver(new Callback()); e->setVideoSrc(VideoSrcType::URL, "rtsp://192.168.1.75:554/stream1"); e->start(); - e->destroy(); + log4cpp::Category::shutdown(); } diff --git a/python/model.pkl b/python/model.pkl new file mode 100644 index 0000000..e26d88b --- /dev/null +++ b/python/model.pkl @@ -0,0 +1,453 @@ +(lp1 +ccopy_reg +_reconstructor +p2 +(csklearn.linear_model.logistic +LogisticRegression +p3 +c__builtin__ +object +p4 +NtRp5 +(dp6 +S'warm_start' +p7 +I00 +sS'C' +F1 +sS'n_jobs' +p8 +I1 +sS'verbose' +p9 +I0 +sS'fit_intercept' +p10 +I01 +sS'solver' +p11 +S'liblinear' +p12 +sS'classes_' +p13 +cnumpy.core.multiarray +_reconstruct +p14 +(cnumpy +ndarray +p15 +(I0 +tS'b' +tRp16 +(I1 +(L2L +tcnumpy +dtype +p17 +(S'i4' +I0 +I1 +tRp18 +(I3 +S'<' +NNNI-1 +I-1 +I0 +tbI00 +S'\x00\x00\x00\x00\x01\x00\x00\x00' +tbsS'n_iter_' +p19 +g14 +(g15 +(I0 +tS'b' +tRp20 +(I1 +(L1L +tg17 +(S'i4' +I0 +I1 +tRp21 +(I3 +S'<' +NNNI-1 +I-1 +I0 +tbI00 +S'\x0b\x00\x00\x00' +tbsS'intercept_scaling' +p22 +I1 +sS'penalty' +p23 +S'l2' +p24 +sS'multi_class' +p25 +S'ovr' +p26 +sS'random_state' +p27 +NsS'_sklearn_version' +p28 +S'0.19.0' +p29 +sS'dual' +p30 +I00 +sS'tol' +p31 +F0.0001 +sS'coef_' +p32 +g14 +(g15 +(I0 +tS'b' +tRp33 +(I1 +(L1L +L8L +tg17 +(S'f8' +I0 +I1 +tRp34 +(I3 +S'<' +NNNI-1 +I-1 +I0 +tbI00 +S'\xe2\x06dQ\xb5\xd1\x11@\xb4\x06\x1c.{`\xfb?\xea}\x1c\xea \xb7\xbb?TQz6\x17\r\xdc\xbfc}J\xd5\xfb\x81\x0c@\x92\xe9\xf7\r/\xa3\xf4\xbf\xf5\xc9mcQ\xfe\x01\xc0:\xa9`\xd4\xd3\xfb\xf7\xbf' +tbsS'intercept_' +p35 +g14 +(g15 +(I0 +tS'b' +tRp36 +(I1 +(L1L +tg34 +I00 +S'\xee\xa1\x95\x9f\xd3\xeb}?' +tbsS'max_iter' +p37 +I100 +sS'class_weight' +p38 +Nsbag2 +(g3 +g4 +NtRp39 +(dp40 +g7 +I00 +sS'C' +F1 +sg8 +I1 +sg9 +I0 +sg10 +I01 +sg11 +g12 +sg13 +g14 +(g15 +(I0 +tS'b' +tRp41 +(I1 +(L2L +tg18 +I00 +S'\x00\x00\x00\x00\x01\x00\x00\x00' +tbsg19 +g14 +(g15 +(I0 +tS'b' +tRp42 +(I1 +(L1L +tg21 +I00 +S'\x0c\x00\x00\x00' +tbsg22 +I1 +sg23 +g24 +sg25 +g26 +sg27 +Nsg28 +g29 +sg30 +I00 +sg31 +F0.0001 +sg32 +g14 +(g15 +(I0 +tS'b' +tRp43 +(I1 +(L1L +L16L +tg34 +I00 +S'4\xf7\\\xbf\xe6$\x0c@biKO\x01\x00\xf6?\xff\x82\x07*\xe9\xf7\xd2?\xb8\xe3\xd7|\x14O\xdf\xbfKS\xd5?\x10\x00\xaf\xd6!\xa7\x81?\xb7\xbe\xa9,\xb0k\xec?\xf7\x18\x82?\xf8\xa6\xc2\xbfU\xe8Ni\xf2\xa2\xfa?\xff\x84E\xf7!\x98\xdf\xbf\xbb\x91N=\x86\xfd\xf1?\xd7/\xef\xb4\xb3\xc2\xe0?\x0c\x96\x13\x04Al\xd4?\xa4\xd3a\x12\x1d\xce\xc1\xbf\xd2(9\xcd\x9dh\xef?H /\xedC\xcf\xda?J\n\xa1\x1fD\x98\xc1?m%\xc3&\xaa;\xdb?' +tbsg35 +g14 +(g15 +(I0 +tS'b' +tRp50 +(I1 +(L1L +tg34 +I00 +S'\x80\x1d\xf1\xf3(\x89\xe3\xbf' +tbsg37 +I100 +sg38 +Nsbag2 +(g3 +g4 +NtRp51 +(dp52 +g7 +I00 +sS'C' +F1 +sg8 +I1 +sg9 +I0 +sg10 +I01 +sg11 +g12 +sg13 +g14 +(g15 +(I0 +tS'b' +tRp53 +(I1 +(L2L +tg18 +I00 +S'\x00\x00\x00\x00\x01\x00\x00\x00' +tbsg19 +g14 +(g15 +(I0 +tS'b' +tRp54 +(I1 +(L1L +tg21 +I00 +S'\r\x00\x00\x00' +tbsg22 +I1 +sg23 +g24 +sg25 +g26 +sg27 +Nsg28 +g29 +sg30 +I00 +sg31 +F0.0001 +sg32 +g14 +(g15 +(I0 +tS'b' +tRp55 +(I1 +(L1L +L32L +tg34 +I00 +S'N+\xd0R\xed\x8f\x06@\xe0\x9f\xab\xc1\x16\x8f\xf4?\xbe+\xed&\xea\r\xd0?\xb6\xfc,\x074\x8b\xe2\xbf\x8f>f\x8e\xa3\x07\x01@\xe7lu\xbbn\x9b\xf1\xbf>[@s\xc7:\x0e\xc0\x86"\xb8\xaa,\xdb\x85?\xec\xddMJ)\x1a\xea?\xec*\xe1\xfc.\xc5\xd4?\t\xef,\xd8\xf41\xd4?J[\'\x8fP\x14\xab\xbf\x8fA\xea\xd8\xf0b\xe6?3\x8a|\t1\x1c\xc1\xbfn6\xdf\x96\xc6\xe2\xfa?&\x04\x82\x14\x1f5\xdb\xbf\xcc+\x18\xfa\x9f\x0b\xea?\xc6\x1c\xcf\xec\xcf\xa0\xd8?\n\xa5\xbe\xb2\xd7\x1c\xd4?\x82c M\x85\xf6\xbb\xbf\x1b\xb6\x17\x972"\xe1?U\x90/\x04\x8f\xdb\xd2?\x93\xceVv\xe2\x9f\x9c?$p\x16#o\x93\xd7?\x0e\x1a\xfe=\xe7\xc3\xea?h\xfa\xee\xd04y\xdf?\xf8\xad\xff\x86\xf0\xc4\xd1?\xa4N\xb3<87\xb4?\xb8\xc9\xce@\xdc(\xc4?HSA\x9b\xd1I\xcb?g\n\xd9tn\x10\xd0?t\x9db\x95u\xd0\xdc?' +tbsg35 +g14 +(g15 +(I0 +tS'b' +tRp56 +(I1 +(L1L +tg34 +I00 +S'\xbc;\xceA\xf6\xc7\xe2\xbf' +tbsg37 +I100 +sg38 +Nsbag2 +(g3 +g4 +NtRp57 +(dp58 +g7 +I00 +sS'C' +F1 +sg8 +I1 +sg9 +I0 +sg10 +I01 +sg11 +g12 +sg13 +g14 +(g15 +(I0 +tS'b' +tRp59 +(I1 +(L2L +tg18 +I00 +S'\x00\x00\x00\x00\x01\x00\x00\x00' +tbsg19 +g14 +(g15 +(I0 +tS'b' +tRp60 +(I1 +(L1L +tg21 +I00 +S'\x0e\x00\x00\x00' +tbsg22 +I1 +sg23 +g24 +sg25 +g26 +sg27 +Nsg28 +g29 +sg30 +I00 +sg31 +F0.0001 +sg32 +g14 +(g15 +(I0 +tS'b' +tRp61 +(I1 +(L1L +L40L +tg34 +I00 +S'\xd9\xfe\'\xeeW\xc4\x05@\x9f:b\xf1F\xa4\xf4?\xf28\xaf\xdc`\xd4\xc8?\xf5\xcb.\x85#\xa9\xe7\xbf6"e&G\xbb\x00@1\x833\xe4O\xa9\xf1\xbfEm\x1d\x8b\xb9\xdb\x0c\xc0\x83\xd6J\xe3Q|\xcd?\xa5\xd5?\x034"\xe5?\xd7d\x92\xaf\x8ev\xdc?\xe2U:|\xe0\xf9\xd1?\xcf>~0\xadm\xcb\xbf\x92D.n\xf6\'\xde?\x83\x90\x81\x07\xab\xa2\xd2\xbfn\x96L\x15W*\xfa?\xecC\xfaMtr\xdd\xbf4\r\xcew\x90\xad\xe7?\x10Wu\x18\x84"\xe2?\xe7m\xba\xb1pk\xd2?T\x84V\xe1V\x86\xbd\xbf\x82\x1b-\x0c\xb6H\xcf?\xe9F\x06{\x8dW\xc8?\xa5\xb5d\x0e\x1b\x11\xbe\xbf\x9f\xe4y4\xe5n\xd3?\xe0\xf8M\xb1\xff\xb5\xd9?)\xb1e\xc4\xb4\x02\xe0?\x04\xdb\x94\xe2\xfd\xa6\xd1?9\xa7\x04\x89\xff\x90\xbf?\xee\x8f\xab\x1d1\x1c\xc8\xbf&w\xf6X\xd0\x8f\xc1?\xf3\xdeI\xaa\xfc\x02\xb2?\xae\xbf\xa3Sj\xa0\xc1?\x83\xdb\xcf\rr\x0e\xec?\r\xe8\xb1k\xc9\x14\xe0?\x1b"j?o\xc8\xd2?\xdf9_W\xa7{\xcd?.\xbf&\xea\xa2\x1f\xda?\xbeno\xb3\xe8\xda\xb5?\x85&\xcbt\xd3r\xd0?\x04\xbb\xffur\xc1\xde?' +tbsg35 +g14 +(g15 +(I0 +tS'b' +tRp62 +(I1 +(L1L +tg34 +I00 +S'\x90\xe2;h\x9e\xbf\xe2\xbf' +tbsg37 +I100 +sg38 +Nsba. \ No newline at end of file diff --git a/python/predictor.py b/python/predictor.py new file mode 100644 index 0000000..9f21518 --- /dev/null +++ b/python/predictor.py @@ -0,0 +1,35 @@ +from sklearn.linear_model import LogisticRegression +try: + import cPickle as pickle +except: + import pickle + + +predictors = None + +def init(fname = './model.pkl'): + global predictors + f = open(fname, 'rb') + predictors = pickle.load(f) + f.close() + +def dump(): + global predictors + for i in predictors: + print i + print i.coef_ + + +def predict(index, features): + pp = predictors[index] + true_class = int(pp.classes_[1] == 1) + prob = pp.predict_proba([features])[0, true_class] + return prob + + +if __name__ == '__main__': + init('./model.pkl') + dump() + feature=[1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2] + print predict(len(predictors) - 1, feature) + diff --git a/src/Engine.cpp b/src/Engine.cpp index 49a17da..6d44f3c 100644 --- a/src/Engine.cpp +++ b/src/Engine.cpp @@ -2,6 +2,7 @@ #include #include "Engine.h" #include "Logger.h" +#include "PredictorWrapper.h" using namespace suanzi; @@ -42,8 +43,13 @@ void Engine::destroy() void Engine::setVideoSrc(VideoSrcType type, const std::string& url) { + PredictorWrapperPtr pp = PredictorWrapper::create("./python/model.pkl"); + + pp->dump(); + // videoSrc = url; - reader = VideoReaderFactory::createVideoReader(type, url); + //reader = VideoReaderFactory::createVideoReader(type, url); + } void Engine::run() diff --git a/src/Metrics.cpp b/src/Metrics.cpp index 7c451f1..49822f2 100644 --- a/src/Metrics.cpp +++ b/src/Metrics.cpp @@ -1,3 +1,4 @@ +#include "Logger.h" #include "Metrics.h" using namespace suanzi; @@ -11,3 +12,7 @@ Metrics::Metrics(const std::string& clf_path) } } + +void Metrics::similarity(const Patch& p1, const Patch& p2) +{ +} diff --git a/src/MultiTracker.cpp b/src/MultiTracker.cpp index 8df378f..d0e7048 100644 --- a/src/MultiTracker.cpp +++ b/src/MultiTracker.cpp @@ -1,9 +1,11 @@ #include "MultiTracker.h" #include "Metrics.h" #include +#include "hungarian.h" using namespace suanzi; using namespace cv; +using namespace Eigen; MultiTracker::MultiTracker(MetricsPtr m) : metrics(m) { @@ -24,12 +26,12 @@ TrackerPtr MultiTracker::createTracker(int id) void MultiTracker::addTracker(TrackerPtr t) { - trackers.insert(t); + trackers.push_back(t); } void MultiTracker::removeTracker(TrackerPtr t) { - trackers.erase(t); +// trackers.erase(t); } void MultiTracker::initNewTrackers(cv::Mat& iamge) @@ -41,35 +43,49 @@ void MultiTracker::correctTrackers(MetricsPtr m, Mat& image) { } - -void MultiTracker::update(unsigned int total, const Detection* d, const Mat& image) +void calculate_edistance() { +} - // correct_trackers - - - // - - - - - - +#define MaxCost 100000 +void MultiTracker::update(unsigned int total, const Detection* detections, const Mat& image) +{ + // correct_trackers + // Generate cost matrix + int row = trackers.size(); + int col = total; + MatrixXi cost_matrix = MatrixXi::Zero(row, col); + for (int i = 0; i < row; i++){ + for (int j = 0; j < col; j++){ + TrackerPtr tracker = trackers[i]; + Detection det = detections[j]; + + int cost = MaxCost; + + // TODO + cost_matrix(i, j) = cost; + } + } + + // assignment + VectorXi tracker_inds, bb_inds; + linear_sum_assignment(cost_matrix, tracker_inds, bb_inds); + + // handle the result + vector unmatched_trackers; + vector unmatched_detection; + for (int i = 0; i < row; i++){ + if (!(tracker_inds.array() == i).any()){ + unmatched_trackers.push_back(trackers[i]); + } + } + for(int j = 0; j < col; j++){ + if (!(bb_inds.array() == j).any()){ + unmatched_detection.push_back(detections[j]); + } + } - // Delete long lost trackers; -// for (auto& t : trackers){ -// if (t->status == TrackerStatus::Delete) -// trackers.erase(t); -// } -// - // Update trackers using kalman filter -// for(auto& t: trackers){ -// //t.bb_ltrb = -// } -// - // associate trackers with detections -// correctTrackers(this->metric, image); // create new trackers for new detections } diff --git a/src/PredictorWrapper.cpp b/src/PredictorWrapper.cpp new file mode 100644 index 0000000..8c5965f --- /dev/null +++ b/src/PredictorWrapper.cpp @@ -0,0 +1,104 @@ +#include "PredictorWrapper.h" +#include +#include + +namespace py = boost::python; + +using namespace std; +using namespace suanzi; + + +PredictorWrapperWPtr PredictorWrapper::instance; + +const static std::string PREDICTOR_PY_DIR = "./python"; + +static std::string parse_python_exception(); + + +PredictorWrapperPtr PredictorWrapper::create(const std::string& fname) +{ + //if (instance == nullptr){ + if (instance.lock()){ + //instance = new PredictorWrapper(fname); + return PredictorWrapperPtr(); + } + PredictorWrapperPtr ins (new PredictorWrapper(fname)); + instance = ins; + return ins; +} + +PredictorWrapper::PredictorWrapper(const std::string& fname) +{ + Py_Initialize(); + try{ + py::object main_module = py::import("__main__"); + py::object main_namespace = main_module.attr("__dict__"); + py::exec("import sys", main_namespace); + std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')"; + py::exec(cmd.c_str(), main_namespace); + //py::exec("sys.path.insert(0, '/home/debian/project/tracker/python')", main_namespace); + //py::exec("sys.path.insert(0, './python')", main_namespace); + py::object predictor_mod = py::import("predictor"); + py::object predictor_init = predictor_mod.attr("init"); + dump_func = predictor_mod.attr("dump"); + predict_func = predictor_mod.attr("predict"); + + predictor_init(fname.c_str()); + //predictor_dump(); + //py::exec("import predictor", main_namespace); + } catch (boost::python::error_already_set const &){ + std::string perror_str = parse_python_exception(); + std::cout << "Error in Python: " << perror_str << std::endl; + } +} + +static std::string parse_python_exception(){ + PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL; + // Fetch the exception info from the Python C API + PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr); + + // Fallback error + std::string ret("Unfetchable Python error"); + // If the fetch got a type pointer, parse the type into the exception string + if(type_ptr != NULL){ + py::handle<> h_type(type_ptr); + py::str type_pstr(h_type); + // Extract the string from the boost::python object + py::extract e_type_pstr(type_pstr); + // If a valid string extraction is available, use it + // otherwise use fallback + if(e_type_pstr.check()) + ret = e_type_pstr(); + else + ret = "Unknown exception type"; + } + // Do the same for the exception value (the stringification of the exception) + if(value_ptr != NULL){ + py::handle<> h_val(value_ptr); + py::str a(h_val); + py::extract returned(a); + if(returned.check()) + ret += ": " + returned(); + else + ret += std::string(": Unparseable Python error: "); + } + // Parse lines from the traceback using the Python traceback module + if(traceback_ptr != NULL){ + py::handle<> h_tb(traceback_ptr); + // Load the traceback module and the format_tb function + py::object tb(py::import("traceback")); + py::object fmt_tb(tb.attr("format_tb")); + // Call format_tb to get a list of traceback strings + py::object tb_list(fmt_tb(h_tb)); + // Join the traceback strings into a single string + py::object tb_str(py::str("\n").join(tb_list)); + // Extract the string, check the extraction, and fallback in necessary + py::extract returned(tb_str); + if(returned.check()) + ret += ": " + returned(); + else + ret += std::string(": Unparseable Python traceback"); + } + return ret; +} + diff --git a/src/hungarian.cpp b/src/hungarian.cpp index 85b49e5..7c1d8da 100644 --- a/src/hungarian.cpp +++ b/src/hungarian.cpp @@ -301,3 +301,15 @@ int step_six(Hungary& state) } return 4; } + +//////////////////////////////////////////////////////////////////////////////// +double distance_cosine(const VectorXd& u, const VectorXd& v) +{ + return (1 - u.dot(v) / std::sqrt(u.dot(u) * v.dot(v))); +} + +double distance_euclidean(const VectorXd& u, const VectorXd& v) +{ + VectorXd d = u - v; + return std::sqrt(d.dot(d)); +} diff --git a/test/TestHungarian.cpp b/test/TestHungarian.cpp index f20ec09..3d6173f 100644 --- a/test/TestHungarian.cpp +++ b/test/TestHungarian.cpp @@ -1,5 +1,6 @@ #include "hungarian.h" #include "gtest/gtest.h" +#include using namespace std; using namespace Eigen; @@ -44,3 +45,36 @@ TEST(Hungarian, 4x3) EXPECT_TRUE(expect_row_ind == row_ind); EXPECT_TRUE(expect_col_ind == col_ind); } + +TEST(Distance, consine) +{ + Vector3d u, v; + u << 1, 0, 0; + v << 0, 1, 0; + double d = distance_cosine(u, v); + EXPECT_DOUBLE_EQ(d, 1.0); + + u << 100, 0, 0; + v << 0, 1, 0; + d = distance_cosine(u, v); + EXPECT_DOUBLE_EQ(d, 1.0); + + u << 1, 1, 0; + v << 0, 1, 0; + d = distance_cosine(u, v); + EXPECT_TRUE(std::abs(d - 0.2928932) < 0.0001); +} + +TEST(Distance, euclidean) +{ + Vector3d u, v; + u << 1, 0, 0; + v << 0, 1, 0; + double d = distance_euclidean(u, v); + EXPECT_TRUE(std::abs(d - 1.41421356) < 0.0001); + + u << 1, 1, 0; + v << 0, 1, 0; + d = distance_euclidean(u, v); + EXPECT_DOUBLE_EQ(d, 1.0); +} -- 2.11.0