08月12, 2020

6. Long Object 的乘法

整数的基本运算

上一节讲到,在 PyLong_Type 中定义了整数类型的各种属性,比如整数类型的名称 “int”。整数对象最常用的是一些数学运算,整数对象当然也是支持这些方法的,它们定义在:

// Object/longobject.c:5632
PyTypeObject PyLong_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    "int",                                      /* tp_name */
    offsetof(PyLongObject, ob_digit),           /* tp_basicsize */
    sizeof(digit),                              /* tp_itemsize */
    0,                                          /* tp_dealloc */
    0,                                          /* tp_vectorcall_offset */
    0,                                          /* tp_getattr */
    0,                                          /* tp_setattr */
    0,                                          /* tp_as_async */
    long_to_decimal_string,                     /* tp_repr */
    &long_as_number,                            /* tp_as_number */
    /* 
    ************忽略了一些代码************
    */
    PyObject_Del,                               /* tp_free */
};

注意其中 &long_as_number 一行,它是一个函数指针数组,其中定义了整数的一系列数学方法,我们在用整数对象进行数学计算的时候,实质上调用的是这个数组里的指针。

// Object/longobject.c:5595,就在 PyLong_Type 定义上方
static PyNumberMethods long_as_number = {
    (binaryfunc)long_add,       /*nb_add*/
    (binaryfunc)long_sub,       /*nb_subtract*/
    (binaryfunc)long_mul,       /*nb_multiply*/
    long_mod,                   /*nb_remainder*/
    long_divmod,                /*nb_divmod*/
    long_pow,                   /*nb_power*/
    (unaryfunc)long_neg,        /*nb_negative*/
    long_long,                  /*tp_positive*/
    (unaryfunc)long_abs,        /*tp_absolute*/
    (inquiry)long_bool,         /*tp_bool*/
    (unaryfunc)long_invert,     /*nb_invert*/
    long_lshift,                /*nb_lshift*/
    long_rshift,                /*nb_rshift*/
    long_and,                   /*nb_and*/
    long_xor,                   /*nb_xor*/
    long_or,                    /*nb_or*/
    long_long,                  /*nb_int*/
    0,                          /*nb_reserved*/
    long_float,                 /*nb_float*/
    0,                          /* nb_inplace_add */
    0,                          /* nb_inplace_subtract */
    0,                          /* nb_inplace_multiply */
    0,                          /* nb_inplace_remainder */
    0,                          /* nb_inplace_power */
    0,                          /* nb_inplace_lshift */
    0,                          /* nb_inplace_rshift */
    0,                          /* nb_inplace_and */
    0,                          /* nb_inplace_xor */
    0,                          /* nb_inplace_or */
    long_div,                   /* nb_floor_divide */
    long_true_divide,           /* nb_true_divide */
    0,                          /* nb_inplace_floor_divide */
    0,                          /* nb_inplace_true_divide */
    long_long,                  /* nb_index */
};

通过结构体成员的命名就可以大致看出这些函数的作用,注意到其中一些方法没有实现,被设置为 0,这些是目前整数对象不支持的方法。绝大多数函数功能都不是很复杂,这里以比较复杂的乘法稍作解释。

三种的乘法的实现

第一种,也是最简单的,对于那些大小在 1 digit 的整数,直接通过 a * b 计算,一条 CPU 指令就可以得到结果。

第二种稍微复杂一点,对于那些不能在一个寄存器宽度内计算长整数,可以采用算式模拟(gradeschool math)。其原理非常简单,需要回顾小学的乘法竖式。假如 a = 123, b = 45, 计算 a * b.

img

将两个长整数乘法转换为一个长整数与 1 位整数按位相乘,结果求和的过程。这个方法不仅在 10 进制下,其他进制下也是成立的。

第三种是 Python 实现用用到的 Karatsuba 算法,这是一种快速实现长整数乘法的算法。其基本原理是将长整数按照一个基数 X 分解为 high、low 两部分,对于两个长整数 a,b 有:

$$ a b = a_{hi} b_{hi} X X + (k - a_{hi} b_{hi} - a_{lo} b_{lo}) X + a_{lo} b_{lo} $$

其中 $k = (a_{hi} + a_{lo}) * (b_{hi} + b_{lo})$ ,这样就把一个长整数乘法转换为:

  • $a_{hi} * b_{hi}$
  • $(a_{hi} + a_{lo}) * (b_{hi} + b_{lo})$
  • $a_{lo} * b_{lo}$

三个长度更短的整数运算,利用递归思想,长整数的运算很快被转换为一系列短整数乘法。

整数乘法的优化

Python 在实现整数乘法时做了三个优化,

第一个:

// Object/longobject.c:3551
static PyObject *
long_mul(PyLongObject *a, PyLongObject *b)
{
    PyLongObject *z;

    CHECK_BINOP(a, b);

    /* fast path for single-digit multiplication */
    if (Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1) {
        stwodigits v = (stwodigits)(MEDIUM_VALUE(a)) * MEDIUM_VALUE(b);
        return PyLong_FromLongLong((long long)v);
    }

    z = k_mul(a, b);
    /* Negate if exactly one of the inputs is negative. */
    if (((Py_SIZE(a) ^ Py_SIZE(b)) < 0) && z) {
        _PyLong_Negate(&z);
        if (z == NULL)
            return NULL;
    }
    return (PyObject *)z;
}

如果整数的长度只有一个 digit,也就是 30 bits,那么就可以走捷径,直接用 C 语言内置乘法计算结果,仅从计算效率上来说,与 C 语言一样。

如果整数长度大于 1 个 digit,那么就去 k_mul 函数:

/* Karatsuba multiplication.  Ignores the input signs, and returns the
 * absolute value of the product (or NULL if error).
 * See Knuth Vol. 2 Chapter 4.3.3 (Pp. 294-295).
 */
static PyLongObject *
k_mul(PyLongObject *a, PyLongObject *b)
{
    Py_ssize_t asize = Py_ABS(Py_SIZE(a));
    Py_ssize_t bsize = Py_ABS(Py_SIZE(b));
    PyLongObject *ah = NULL;
    PyLongObject *al = NULL;
    PyLongObject *bh = NULL;
    PyLongObject *bl = NULL;
    PyLongObject *ret = NULL;
    PyLongObject *t1, *t2, *t3;
    Py_ssize_t shift;           /* the number of digits we split off */
    Py_ssize_t i;

    /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
     * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh  + ah*bh + al*bl
     * Then the original product is
     *     ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl
     * By picking X to be a power of 2, "*X" is just shifting, and it's
     * been reduced to 3 multiplies on numbers half the size.
     */

    /* We want to split based on the larger number; fiddle so that b
     * is largest.
     */
    if (asize > bsize) {
        t1 = a;
        a = b;
        b = t1;

        i = asize;
        asize = bsize;
        bsize = i;
    }

    /* Use gradeschool math when either number is too small. */
    i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
    if (asize <= i) {
        if (asize == 0)
            return (PyLongObject *)PyLong_FromLong(0);
        else
            return x_mul(a, b);
    }

 //  下面是 Karatsuba 算法,代码较多不再贴出来

注释中简要描述了 Karatsuba 算法的步骤,如果 a、b 的长度小于一个阈值,那么就使用 gradeschool math 算法计算长整数乘法,这么做的主要原因是低于这个阈值 Karatsuba 算法不再有优势,另外,这个条件也是 k_mul 函数递归调用的退出条件,否则递归永远不会结束。更长的整数乘法才会使用 Karatsuba 算法计算,前一部分介绍 Karatsuba 算法已经看到,Karatsuba 算法是将长整数拆分为若干个短整数的乘法和加法,其中乘法会尝试递归使用 Karatsuba 算法,直到拆分到足够短的整数满足退出条件。

本文链接:http://www.thinkinpython.com/post/deep_ptyhon_vm_6.html

-- EOF --