boost-python ではじめる大規模機械学習(3)- Python モジュールの呼び出し

あらすじ

boost-python を使用して、Python と C 言語両方を活用する方法を説明しています。
前記事では、Python・C 言語間で簡単なオブジェクトを受け渡す方法を解説しました。
本記事では、任意の Python モジュールを C 言語から呼び出す方法を解説します。

準備

以降、色々と SciPy の機能を使うことがあります。このとき、以下のコードを拡張モジュール側に書き込む必要があります。

#include <numpy/arrayobject.h>

BOOST_PYTHON_MODULE(hello)
{
    numeric::array::set_module_and_type("numpy", "ndarray");
    import_array();
}

理由は SciPy のリファレンスに書いてあるハズ。

Python モジュールの読み込み

import 文を使用することで、Python モジュールを読み込むことができます。
各モジュールの属性は attr 関数で呼び出すことができます。
おどろくほどシンプルな動作です。

#include <boost/python.hpp>
#include <numpy/arrayobject.h>
using namespace boost::python;
static object scipy = import("scipy");
static object float64 = scipy.attr("float64");
static object uint8 = scipy.attr("uint8");
static object zeros = scipy.attr("zeros");

Python 関数の呼び出し

非常にシンプルです。迷うところがありません。

object get_float()
{
    return float64;
}

object make_float_zeros(int nx, int ny)
{
    return zeros(make_tuple(nx, ny), float64);
}

object make_byte_zeros(int nx, int ny)
{
    return zeros(make_tuple(nx, ny), uint8);
}

動作テスト

hello.cpp
#include <boost/python.hpp>
#include <numpy/arrayobject.h>
using namespace boost::python;
static object scipy = import("scipy");
static object float64 = scipy.attr("float64");
static object uint8 = scipy.attr("uint8");
static object zeros = scipy.attr("zeros");

object get_float()
{
    return float64;
}

object make_float_zeros(int nx, int ny)
{
    return zeros(make_tuple(nx, ny), float64);
}

object make_byte_zeros(int nx, int ny)
{
    return zeros(make_tuple(nx, ny), uint8);
}

BOOST_PYTHON_MODULE(hello)
{
    numeric::array::set_module_and_type("numpy", "ndarray");
    def("get_float", get_float);
    def("make_float_zeros", make_float_zeros);
    def("make_byte_zeros", make_byte_zeros);
    import_array();
}
実行結果
$ python
>>> import hello
>>> hello.get_float()
<type 'numpy.float64'>
>>> hello.make_float_zeros(2, 3)
array([[ 0.,  0.,  0.],
       [ 0.,  0.,  0.]])
>>> hello.make_byte_zeros(2, 4)
array([[0, 0, 0, 0],
       [0, 0, 0, 0]], dtype=uint8)
>>> 

次回は配列の内容にアクセスする方法を解説します。