Fix nan issue in features
[trackerpp.git] / src / PredictorWrapper.cpp
index ebe7579..77d99af 100644 (file)
@@ -3,6 +3,7 @@
 #include "Logger.h"
 #include <iostream>
 #include <thread>
+#include "PyWrapper.h"
 
 namespace py = boost::python;
 
@@ -10,12 +11,8 @@ using namespace std;
 using namespace suanzi;
 
 
-PredictorWrapperWPtr PredictorWrapper::instance;
-
 const static std::string TAG = "PredictorWrapper";
 
-static std::string parse_python_exception();
-
 template <class T>
 boost::python::list toPythonList(const std::vector<T>& v) {
     typename std::vector<T>::iterator iter;
@@ -26,114 +23,48 @@ boost::python::list toPythonList(const std::vector<T>& v) {
     return list;
 }
 
-PredictorWrapperPtr PredictorWrapper::create(const std::string& python_dir, const std::string& model_path)
-{
-    if (instance.lock()){
-        return PredictorWrapperPtr();
-    }
-    PredictorWrapperPtr ins (new PredictorWrapper(python_dir, model_path));
-    instance = ins;
-    return ins;
-}
+#define CALL_WITH_EXCEPTION(expr, ret_val)                      \
+    try{                                                        \
+        py::object py_ret = expr;                               \
+        ret_val = py::extract<decltype(ret_val)>(py_ret);       \
+    }catch(boost::python::error_already_set const &){           \
+        LOG_ERROR(TAG, PyWrapper::parse_python_exception());    \
+        throw runtime_error("Python error");                    \
+    }                                                           \
 
-void PredictorWrapper::dump()
+
+PredictorWrapper::PredictorWrapper(const string& module, const string& pydir)
 {
-    LOG_DEBUG(TAG, "dump");
-    std::string ss = "";
+    PyWrapperPtr interpreter = PyWrapper::getInstance(pydir);
     try{
-        py::object ret = this->dump_func();
-        ss = py::extract<std::string>(ret);
+        m_module = interpreter->import(module);
     } catch (boost::python::error_already_set const &){
-        std::string perror_str = parse_python_exception();
-        LOG_ERROR(TAG, "Error in Python: " + perror_str)
+        LOG_ERROR(TAG, PyWrapper::parse_python_exception());
     }
-    LOG_DEBUG(TAG, ss);
 }
 
-double PredictorWrapper::predict(int index, const std::vector<double>& ff)
+bool PredictorWrapper::load(const string& fname)
 {
-    LOG_DEBUG(TAG, "predict");
-    py::object ret;
-    try{
-        ret = this->predict_func(index, toPythonList(ff));
-    } catch (boost::python::error_already_set const &){
-        std::string perror_str = parse_python_exception();
-        LOG_ERROR(TAG, "Error in Python: " + perror_str)
-    }
-    double rr = py::extract<double>(ret);
-    LOG_DEBUG(TAG, "return: " + std::to_string(rr));
-    return rr; 
+    LOG_DEBUG(TAG, "load " + fname);
+    py::object func = m_module.attr("init");
+    bool ret = true;
+    CALL_WITH_EXCEPTION(func(fname.c_str()), ret);
+    return ret;
 }
 
-PredictorWrapper::PredictorWrapper(const std::string& py_dir, const std::string& model_path)
+void PredictorWrapper::dump()
 {
-    Py_Initialize();
-    try{
-        py::object main_module = py::import("__main__");
-        py::object main_namespace = main_module.attr("__dict__");
-        py::exec("import sys", main_namespace);
-        std::string cmd = "sys.path.insert(0, '" + py_dir + "')";
-        py::exec(cmd.c_str(), main_namespace);
-        py::exec("import signal", main_namespace);
-        py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace);
-        py::object predictor_mod = py::import("predictor");
-        py::object predictor_init = predictor_mod.attr("init");
-        dump_func = predictor_mod.attr("dump");
-        predict_func = predictor_mod.attr("predict");
-        predictor_init(model_path.c_str());
-    } catch (boost::python::error_already_set const &){
-        std::string perror_str = parse_python_exception();
-        LOG_ERROR(TAG, "Error in Python: " + perror_str)
-    }
+    LOG_DEBUG(TAG, "dump");
+    py::object func = m_module.attr("dump");
+    string ret = "";
+    CALL_WITH_EXCEPTION(func(), ret);
+    LOG_DEBUG(TAG, ret);
 }
 
-static std::string parse_python_exception(){  
-    PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;  
-    // Fetch the exception info from the Python C API  
-    PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);  
-  
-    // Fallback error  
-    std::string ret("Unfetchable Python error");  
-    // If the fetch got a type pointer, parse the type into the exception string  
-    if(type_ptr != NULL){  
-        py::handle<> h_type(type_ptr);  
-        py::str type_pstr(h_type);  
-        // Extract the string from the boost::python object  
-        py::extract<std::string> e_type_pstr(type_pstr);  
-        // If a valid string extraction is available, use it   
-        //  otherwise use fallback  
-        if(e_type_pstr.check())  
-            ret = e_type_pstr();  
-        else  
-            ret = "Unknown exception type";  
-    }  
-    // Do the same for the exception value (the stringification of the exception)  
-    if(value_ptr != NULL){  
-        py::handle<> h_val(value_ptr);  
-        py::str a(h_val);  
-        py::extract<std::string> returned(a);  
-        if(returned.check())  
-            ret +=  ": " + returned();  
-        else  
-            ret += std::string(": Unparseable Python error: ");  
-    }  
-    // Parse lines from the traceback using the Python traceback module  
-    if(traceback_ptr != NULL){  
-        py::handle<> h_tb(traceback_ptr);  
-        // Load the traceback module and the format_tb function  
-        py::object tb(py::import("traceback"));  
-        py::object fmt_tb(tb.attr("format_tb"));  
-        // Call format_tb to get a list of traceback strings  
-        py::object tb_list(fmt_tb(h_tb));  
-        // Join the traceback strings into a single string  
-        py::object tb_str(py::str("\n").join(tb_list));  
-        // Extract the string, check the extraction, and fallback in necessary  
-        py::extract<std::string> returned(tb_str);  
-        if(returned.check())  
-            ret += ": " + returned();  
-        else  
-            ret += std::string(": Unparseable Python traceback");  
-    }  
-    return ret;  
-}  
-
+double PredictorWrapper::predict(int index, const vector<double>& features)
+{
+    py::object func = m_module.attr("predict");
+    double rr = 0;
+    CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr);
+    return rr;
+}