#include "PredictorWrapper.h"
#include <string>
+#include "Logger.h"
#include <iostream>
+#include <thread>
+#include "PyWrapper.h"
namespace py = boost::python;
using namespace suanzi;
-PredictorWrapperWPtr PredictorWrapper::instance;
+const static std::string TAG = "PredictorWrapper";
-const static std::string PREDICTOR_PY_DIR = "./python";
+template <class T>
+boost::python::list toPythonList(const std::vector<T>& v) {
+ typename std::vector<T>::iterator iter;
+ boost::python::list list;
+ for(const auto& vv : v){
+ list.append(vv);
+ }
+ return list;
+}
-static std::string parse_python_exception();
+#define CALL_WITH_EXCEPTION(expr, ret_val) \
+ try{ \
+ py::object py_ret = expr; \
+ ret_val = py::extract<decltype(ret_val)>(py_ret); \
+ }catch(boost::python::error_already_set const &){ \
+ LOG_ERROR(TAG, PyWrapper::parse_python_exception()); \
+ throw runtime_error("Python error"); \
+ } \
-PredictorWrapperPtr PredictorWrapper::create(const std::string& fname)
+PredictorWrapper::PredictorWrapper(const string& module, const string& pydir)
{
- //if (instance == nullptr){
- if (instance.lock()){
- //instance = new PredictorWrapper(fname);
- return PredictorWrapperPtr();
+ PyWrapperPtr interpreter = PyWrapper::getInstance(pydir);
+ try{
+ m_module = interpreter->import(module);
+ } catch (boost::python::error_already_set const &){
+ LOG_ERROR(TAG, PyWrapper::parse_python_exception());
}
- PredictorWrapperPtr ins (new PredictorWrapper(fname));
- instance = ins;
- return ins;
}
-PredictorWrapper::PredictorWrapper(const std::string& fname)
+bool PredictorWrapper::load(const string& fname)
{
- Py_Initialize();
- try{
- py::object main_module = py::import("__main__");
- py::object main_namespace = main_module.attr("__dict__");
- py::exec("import sys", main_namespace);
- std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')";
- py::exec(cmd.c_str(), main_namespace);
- //py::exec("sys.path.insert(0, '/home/debian/project/tracker/python')", main_namespace);
- //py::exec("sys.path.insert(0, './python')", main_namespace);
- py::object predictor_mod = py::import("predictor");
- py::object predictor_init = predictor_mod.attr("init");
- dump_func = predictor_mod.attr("dump");
- predict_func = predictor_mod.attr("predict");
-
- predictor_init(fname.c_str());
- //predictor_dump();
- //py::exec("import predictor", main_namespace);
- } catch (boost::python::error_already_set const &){
- std::string perror_str = parse_python_exception();
- std::cout << "Error in Python: " << perror_str << std::endl;
- }
+ LOG_DEBUG(TAG, "load " + fname);
+ py::object func = m_module.attr("init");
+ bool ret = true;
+ CALL_WITH_EXCEPTION(func(fname.c_str()), ret);
+ return ret;
}
-static std::string parse_python_exception(){
- PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;
- // Fetch the exception info from the Python C API
- PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);
-
- // Fallback error
- std::string ret("Unfetchable Python error");
- // If the fetch got a type pointer, parse the type into the exception string
- if(type_ptr != NULL){
- py::handle<> h_type(type_ptr);
- py::str type_pstr(h_type);
- // Extract the string from the boost::python object
- py::extract<std::string> e_type_pstr(type_pstr);
- // If a valid string extraction is available, use it
- // otherwise use fallback
- if(e_type_pstr.check())
- ret = e_type_pstr();
- else
- ret = "Unknown exception type";
- }
- // Do the same for the exception value (the stringification of the exception)
- if(value_ptr != NULL){
- py::handle<> h_val(value_ptr);
- py::str a(h_val);
- py::extract<std::string> returned(a);
- if(returned.check())
- ret += ": " + returned();
- else
- ret += std::string(": Unparseable Python error: ");
- }
- // Parse lines from the traceback using the Python traceback module
- if(traceback_ptr != NULL){
- py::handle<> h_tb(traceback_ptr);
- // Load the traceback module and the format_tb function
- py::object tb(py::import("traceback"));
- py::object fmt_tb(tb.attr("format_tb"));
- // Call format_tb to get a list of traceback strings
- py::object tb_list(fmt_tb(h_tb));
- // Join the traceback strings into a single string
- py::object tb_str(py::str("\n").join(tb_list));
- // Extract the string, check the extraction, and fallback in necessary
- py::extract<std::string> returned(tb_str);
- if(returned.check())
- ret += ": " + returned();
- else
- ret += std::string(": Unparseable Python traceback");
- }
- return ret;
-}
+void PredictorWrapper::dump()
+{
+ LOG_DEBUG(TAG, "dump");
+ py::object func = m_module.attr("dump");
+ string ret = "";
+ CALL_WITH_EXCEPTION(func(), ret);
+ LOG_DEBUG(TAG, ret);
+}
+double PredictorWrapper::predict(int index, const vector<double>& features)
+{
+ py::object func = m_module.attr("predict");
+ double rr = 0;
+ CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr);
+ return rr;
+}