#include "Logger.h"
#include <iostream>
#include <thread>
+#include "PyWrapper.h"
namespace py = boost::python;
using namespace suanzi;
-PredictorWrapperWPtr PredictorWrapper::instance;
-
-const static std::string PREDICTOR_PY_DIR = "./python";
const static std::string TAG = "PredictorWrapper";
-static std::string parse_python_exception();
-
template <class T>
-boost::python::list toPythonList(std::vector<T> vector) {
+boost::python::list toPythonList(const std::vector<T>& v) {
typename std::vector<T>::iterator iter;
boost::python::list list;
- for (iter = vector.begin(); iter != vector.end(); ++iter) {
- list.append(*iter);
+ for(const auto& vv : v){
+ list.append(vv);
}
return list;
}
-PredictorWrapperPtr PredictorWrapper::create(const std::string& fname)
-{
- if (instance.lock()){
- return PredictorWrapperPtr();
- }
- PredictorWrapperPtr ins (new PredictorWrapper(fname));
- instance = ins;
- return ins;
-}
+#define CALL_WITH_EXCEPTION(expr, ret_val) \
+ try{ \
+ py::object py_ret = expr; \
+ ret_val = py::extract<decltype(ret_val)>(py_ret); \
+ }catch(boost::python::error_already_set const &){ \
+ LOG_ERROR(TAG, PyWrapper::parse_python_exception()); \
+ throw runtime_error("Python error"); \
+ } \
-void PredictorWrapper::dump()
+
+PredictorWrapper::PredictorWrapper(const string& module, const string& pydir)
{
- LOG_DEBUG(TAG, "dump");
- std::string ss = "";
+ PyWrapperPtr interpreter = PyWrapper::getInstance(pydir);
try{
- py::object ret = this->dump_func();
- ss = py::extract<std::string>(ret);
+ m_module = interpreter->import(module);
} catch (boost::python::error_already_set const &){
- std::string perror_str = parse_python_exception();
- LOG_ERROR(TAG, "Error in Python: " + perror_str)
+ LOG_ERROR(TAG, PyWrapper::parse_python_exception());
}
- LOG_DEBUG(TAG, ss);
}
-double PredictorWrapper::predict(int index, std::vector<double> ff)
+bool PredictorWrapper::load(const string& fname)
{
- LOG_DEBUG(TAG, "predict");
- try{
- this->predict_func(index, toPythonList(ff));
- } catch (boost::python::error_already_set const &){
- std::string perror_str = parse_python_exception();
- LOG_ERROR(TAG, "Error in Python: " + perror_str)
- }
- return 0.1;
+ LOG_DEBUG(TAG, "load " + fname);
+ py::object func = m_module.attr("init");
+ bool ret = true;
+ CALL_WITH_EXCEPTION(func(fname.c_str()), ret);
+ return ret;
}
-PredictorWrapper::PredictorWrapper(const std::string& fname)
+void PredictorWrapper::dump()
{
- Py_Initialize();
- try{
- py::object main_module = py::import("__main__");
- py::object main_namespace = main_module.attr("__dict__");
- py::exec("import sys", main_namespace);
- std::string cmd = "sys.path.insert(0, '" + PREDICTOR_PY_DIR + "')";
- py::exec(cmd.c_str(), main_namespace);
- py::exec("import signal", main_namespace);
- py::exec("signal.signal(signal.SIGINT, signal.SIG_DFL)", main_namespace);
- py::object predictor_mod = py::import("predictor");
- py::object predictor_init = predictor_mod.attr("init");
- dump_func = predictor_mod.attr("dump");
- predict_func = predictor_mod.attr("predict");
- predictor_init(fname.c_str());
- } catch (boost::python::error_already_set const &){
- std::string perror_str = parse_python_exception();
- LOG_ERROR(TAG, "Error in Python: " + perror_str)
- }
+ LOG_DEBUG(TAG, "dump");
+ py::object func = m_module.attr("dump");
+ string ret = "";
+ CALL_WITH_EXCEPTION(func(), ret);
+ LOG_DEBUG(TAG, ret);
}
-static std::string parse_python_exception(){
- PyObject *type_ptr = NULL, *value_ptr = NULL, *traceback_ptr = NULL;
- // Fetch the exception info from the Python C API
- PyErr_Fetch(&type_ptr, &value_ptr, &traceback_ptr);
-
- // Fallback error
- std::string ret("Unfetchable Python error");
- // If the fetch got a type pointer, parse the type into the exception string
- if(type_ptr != NULL){
- py::handle<> h_type(type_ptr);
- py::str type_pstr(h_type);
- // Extract the string from the boost::python object
- py::extract<std::string> e_type_pstr(type_pstr);
- // If a valid string extraction is available, use it
- // otherwise use fallback
- if(e_type_pstr.check())
- ret = e_type_pstr();
- else
- ret = "Unknown exception type";
- }
- // Do the same for the exception value (the stringification of the exception)
- if(value_ptr != NULL){
- py::handle<> h_val(value_ptr);
- py::str a(h_val);
- py::extract<std::string> returned(a);
- if(returned.check())
- ret += ": " + returned();
- else
- ret += std::string(": Unparseable Python error: ");
- }
- // Parse lines from the traceback using the Python traceback module
- if(traceback_ptr != NULL){
- py::handle<> h_tb(traceback_ptr);
- // Load the traceback module and the format_tb function
- py::object tb(py::import("traceback"));
- py::object fmt_tb(tb.attr("format_tb"));
- // Call format_tb to get a list of traceback strings
- py::object tb_list(fmt_tb(h_tb));
- // Join the traceback strings into a single string
- py::object tb_str(py::str("\n").join(tb_list));
- // Extract the string, check the extraction, and fallback in necessary
- py::extract<std::string> returned(tb_str);
- if(returned.check())
- ret += ": " + returned();
- else
- ret += std::string(": Unparseable Python traceback");
- }
- return ret;
-}
-
+double PredictorWrapper::predict(int index, const vector<double>& features)
+{
+ py::object func = m_module.attr("predict");
+ double rr = 0;
+ CALL_WITH_EXCEPTION(func(index, toPythonList(features)), rr);
+ return rr;
+}