使用基本数组转换的内联 Weave(无 Blitz)¶
日期 | 2011-08-05(最后修改),2008-05-28(创建) |
---|
Python 和 Numpy 被设计用来表达对许多大小的传入数据透明地工作的通用语句。使用内联 Weave 和 Blitz 转换可以显着加快许多数值运算(例如,一系列数组的加法),因为在某些方面它绕过了通用性。如何在保持通用性的同时使用内联 C 代码加快算法速度?Numpy 提供的一个工具是**迭代器**。因为迭代器会为你跟踪内存索引,所以它的操作与 Python 本身中的迭代概念非常类似。你可以编写 C 循环,它简单地说“从串行对象(!PyArrayObject)中获取下一个元素,并对其进行操作,直到没有更多元素。”
这是一个非常简单的多维迭代器示例,以及它们“广播”兼容形状数组的能力。它表明完全不知道维度的相同代码可以根据广播规则实现完全不同的计算。在本例中,我假设**a** 至少与**b** 具有相同数量的维度。重要的是要知道,**a** 的 weave 数组转换让你在 C++ 中访问:*py_a -- !PyObject * *a_array -- !PyArrayObject * *a -- (c-type *) py_array->data
import numpy as npy
from scipy.weave import inline
def multi_iter_example():
a = npy.ones((4,4), npy.float64)
# for the sake of driving home the "dynamic code" approach...
dtype2ctype = {
npy.dtype(npy.float64): 'double',
npy.dtype(npy.float32): 'float',
npy.dtype(npy.int32): 'int',
npy.dtype(npy.int16): 'short',
}
dt = dtype2ctype.get(a.dtype)
# this code does a = a*b inplace, broadcasting b to fit the shape of a
code = \
"""
%s *p1, *p2;
PyObject *itr;
itr = PyArray_MultiIterNew(2, a_array, b_array);
while(PyArray_MultiIter_NOTDONE(itr)) {
p1 = (%s *) PyArray_MultiIter_DATA(itr, 0);
p2 = (%s *) PyArray_MultiIter_DATA(itr, 1);
*p1 = (*p1) * (*p2);
PyArray_MultiIter_NEXT(itr);
}
""" % (dt, dt, dt)
b = npy.arange(4, dtype=a.dtype)
print '\n A B '
print a, b
# this reshaping is redundant, it would be the default broadcast
b.shape = (1,4)
inline(code, ['a', 'b'])
print "\ninline version of a*b[None,:],"
print a
a = npy.ones((4,4), npy.float64)
b = npy.arange(4, dtype=a.dtype)
b.shape = (4,1)
inline(code, ['a', 'b'])
print "\ninline version of a*b[:,None],"
print a
在 iterators_example.py 和 iterators.py 中还有另外两个迭代器应用。
深入了解“内联”方法¶
inline 的 文档字符串 非常庞大,它表明在集成您的内联代码时,支持各种编译选项。我利用这一点,使一些专门的 FFTW 调用变得更加简单,并且仅用几行代码就添加了对内联 FFT 的支持。在这个例子中,我读取了一个纯 C 代码文件,并将其用作我内联语句中的 support_code。我还使用 Numpy 的 distutils 中的一个工具来定位我的 FFTW 库和头文件。
import numpy as N
from scipy.weave import inline
from os.path import join, split
from numpy.distutils.system_info import get_info
fft1_code = \
"""
char *i, *o;
i = (char *) a;
o = inplace ? i : (char *) b;
if(isfloat) {
cfft1d(reinterpret_cast<fftwf_complex*>(i),
reinterpret_cast<fftwf_complex*>(o),
xdim, len_array, direction, shift);
} else {
zfft1d(reinterpret_cast<fftw_complex*>(i),
reinterpret_cast<fftw_complex*>(o),
xdim, len_array, direction, shift);
}
"""
extra_code = open(join(split(__file__)[0],'src/cmplx_fft.c')).read()
fftw_info = get_info('fftw3')
def fft1(a, shift=True, inplace=False):
if inplace:
_fft1_work(a, -1, shift, inplace)
else:
return _fft1_work(a, -1, shift, inplace)
def ifft1(a, shift=True, inplace=False):
if inplace:
_fft1_work(a, +1, shift, inplace)
else:
return _fft1_work(a, +1, shift, inplace)
def _fft1_work(a, direction, shift, inplace):
# to get correct C-code, b always must be an array (but if it's
# not being used, it can be trivially small)
b = N.empty_like(a) if not inplace else N.array([1j], a.dtype)
inplace = 1 if inplace else 0
shift = 1 if shift else 0
isfloat = 1 if a.dtype.itemsize==8 else 0
len_array = N.product(a.shape)
xdim = a.shape[-1]
inline(fft1_code, ['a', 'b', 'isfloat', 'inplace',
'len_array', 'xdim', 'direction', 'shift'],
support_code=extra_code,
headers=['<fftw3.h>'],
libraries=['fftw3', 'fftw3f'],
include_dirs=fftw_info['include_dirs'],
library_dirs=fftw_info['library_dirs'],
compiler='gcc')
if not inplace:
return b
此代码可在 [附件:fftmod.tar.gz fftmod.tar.gz] 中找到。
章节作者: Unknown[6], DavidCooke, TravisOliphant, Unknown[159], Unknown[160], Unknown[161], Unknown[162], FernandoPerez, Unknown[17], PauliVirtanen
附件