Multidot

日期2007-03-24(最后修改),2007-03-24(创建)

矩阵乘法函数 numpy.dot() 仅接受两个参数。这意味着要将两个以上的数组相乘,您最终会得到嵌套的函数调用,这些调用难以阅读

在 [ ]
dot(dot(dot(a,b),c),d)

与中缀表示法相比,您只需编写

在 [ ]
a*b*c*d

有两种方法可以定义一个类似于 dot 但接受两个以上参数的 'mdot' 函数。使用其中一种方法,您可以将上述表达式写成

在 [ ]
mdot(a,b,c,d)

使用 reduce

最简单的方法是使用 reduce。

在 [ ]
def mdot(*args):
    return reduce(numpy.dot, args)

或者使用等效的循环(据说是 Py3K 的首选风格

在 [ ]
def mdot(*args):
    ret = args[0]
    for a in args[1:]:
        ret = dot(ret,a)
    return ret

这将始终为您提供从左到右的结合性,即表达式被解释为 `(((a*b)*c)*d)`。

您还可以创建一个循环的右结合版本

在 [ ]
def mdotr(*args):
    ret = args[-1]
    for a in reversed(args[:-1]):
        ret = dot(a,ret)
    return ret

它计算为 `(a*(b*(c*d)))`。但有时您可能希望有更精细的控制,因为矩阵乘法的执行顺序会对性能产生重大影响。下一个版本提供了这种控制。

控制评估顺序

如果我们愿意牺牲 Numpy 将元组视为数组的能力,我们可以使用元组作为分组构造。这个版本的 `mdot` 允许使用以下语法

在 [ ]
mdot(a,((b,c),d))

来控制成对 `dot` 调用的执行顺序。

在 [ ]
import types
import numpy
def mdot(*args):
   """Multiply all the arguments using matrix product rules.
   The output is equivalent to multiplying the arguments one by one
   from left to right using dot().
   Precedence can be controlled by creating tuples of arguments,
   for instance mdot(a,((b,c),d)) multiplies a (a*((b*c)*d)).
   Note that this means the output of dot(a,b) and mdot(a,b) will differ if
   a or b is a pure tuple of numbers.
   """
   if len(args)==1:
       return args[0]
   elif len(args)==2:
       return _mdot_r(args[0],args[1])
   else:
       return _mdot_r(args[:-1],args[-1])

def _mdot_r(a,b):
   """Recursive helper for mdot"""
   if type(a)==types.TupleType:
       if len(a)>1:
           a = mdot(*a)
       else:
           a = a[0]
   if type(b)==types.TupleType:
       if len(b)>1:
           b = mdot(*b)
       else:
           b = b[0]
   return numpy.dot(a,b)

乘法

请注意,元素级乘法函数 `numpy.multiply` 与 `numpy.dot` 具有相同的两个参数限制。可以为 multiply 定义完全相同的广义形式。

左结合版本

在 [ ]
def mmultiply(*args):
    return reduce(numpy.multiply, args)
在 [ ]
def mmultiply(*args):
    ret = args[0]
    for a in args[1:]:
        ret = multiply(ret,a)
    return ret

右结合版本

在 [ ]
def mmultiplyr(*args):
    ret = args[-1]
    for a in reversed(args[:-1]):
        ret = multiply(a,ret)
    return ret

使用元组控制求值顺序的版本

在 [ ]
import types
import numpy
def mmultiply(*args):
   """Multiply all the arguments using elementwise product.
   The output is equivalent to multiplying the arguments one by one
   from left to right using multiply().
   Precedence can be controlled by creating tuples of arguments,
   for instance mmultiply(a,((b,c),d)) multiplies a (a*((b*c)*d)).
   Note that this means the output of multiply(a,b) and mmultiply(a,b) will differ if
   a or b is a pure tuple of numbers.
   """
   if len(args)==1:
       return args[0]
   elif len(args)==2:
       return _mmultiply_r(args[0],args[1])
   else:
       return _mmultiply_r(args[:-1],args[-1])

def _mmultiply_r(a,b):
   """Recursive helper for mmultiply"""
   if type(a)==types.TupleType:
       if len(a)>1:
           a = mmultiply(*a)
       else:
           a = a[0]
   if type(b)==types.TupleType:
       if len(b)>1:
           b = mmultiply(*b)
       else:
           b = b[0]
   return numpy.multiply(a,b)

章节作者:BillBaxter