V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
推荐学习书目
Learn Python the Hard Way
Python Sites
PyPI - Python Package Index
http://diveintopython.org/toc/index.html
Pocoo
值得关注的项目
PyPy
Celery
Jinja2
Read the Docs
gevent
pyenv
virtualenv
Stackless Python
Beautiful Soup
结巴中文分词
Green Unicorn
Sentry
Shovel
Pyflakes
pytest
Python 编程
pep8 Checker
Styles
PEP 8
Google Python Style Guide
Code Style from The Hitchhiker's Guide
dwjgwsm
V2EX  ›  Python

为什么我用 numba 速度不升反降?

  •  
  •   dwjgwsm · 2018-04-04 14:25:14 +08:00 · 4852 次点击
    这是一个创建于 2414 天前的主题,其中的信息可能已经有所发展或是发生改变。

    看了这篇文章 https://zhuanlan.zhihu.com/p/24168485 试了一下里面的 ma_numba 函数

    import time

    @numba.jit

    def ma_numba(data, ma_length):

    ma = []
    data_window = data[:ma_length]
    test_data = data[ma_length:]
    
    for new_tick in test_data:
        data_window.pop(0)
        data_window.append(new_tick)
        sum_tick = 0
        for tick in data_window:
            sum_tick += tick
        ma.append(sum_tick/ma_length)
    
    
    a = np.arange(10000)
    t1 = time.time()
    b = list(a)
    bb = ma_numba(b, 5)
    t2 = time.time()
    print(t2 - t1)
    
    
    不用 numba,大概耗时 0.03-0.04 秒,用了 numba,耗时 0.7-0.8 秒......奇了怪了,难道是我的姿势不对?
    
    17 条回复    2018-04-05 11:49:43 +08:00
    neoblackcap
        1
    neoblackcap  
       2018-04-04 14:29:12 +08:00 via iPhone
    np 不是本身就是 c 写的吗?你用在这里大概是 jit 也没抵消类型转换啊之类的开销吧。
    要不你用个 pyflame 看看哪里开销大?
    dwjgwsm
        2
    dwjgwsm  
    OP
       2018-04-04 14:36:28 +08:00
    第一,a = np.arange(10000) 这一句是排除在耗时计算之外的.
    第二,b = list(a) 这一句是都被计入耗时之内的.所以对比是不存在这个问题的
    dwjgwsm
        3
    dwjgwsm  
    OP
       2018-04-04 14:39:06 +08:00
    对比时就是简单地把 @numba.jit 这一句注释掉和不注释掉
    ipwx
        4
    ipwx  
       2018-04-04 14:53:24 +08:00
    你这代码本来就不科学啊。data_window.pop 你这是想干嘛啊?还有 sum_tick 有你这种写法嘛?好好的 O(n) 算法你给写成 O(n*k) ?

    In [1]: import numpy as np

    In [2]: import numba

    In [3]: def moving_average(data, k):
    ...: partial_sum = sum(data[:k])
    ...: ret = [partial_sum / k]
    ...: for old_d, new_d in zip(data[:-k], data[k:]):
    ...: partial_sum = partial_sum - old_d + new_d
    ...: ret.append(partial_sum / k)
    ...: return ret
    ...:

    In [4]: numba_moving_average = numba.jit(moving_average)

    In [5]: arr = np.arange(10000)


    In [6]: arr_list = list(arr)

    In [7]: %timeit moving_average(arr_liset)

    In [8]: %timeit moving_average(arr_list, 5)
    3.8 ms ± 9.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    In [9]: %timeit numba_moving_average(arr_list, 5)
    722 µs ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    dwjgwsm
        5
    dwjgwsm  
    OP
       2018-04-04 15:20:03 +08:00
    不对啊,我这结果还是 numba 耗时长啊

    t1 = time.time()
    b=numba_moving_average(a,5)
    t2 = time.time()
    c=moving_average(a,5)
    t3 = time.time()
    print(t2-t1)
    print(t3 - t2)

    结果:

    0.7720441818237305
    0.008000373840332031
    enenaaa
        6
    enenaaa  
       2018-04-04 15:39:04 +08:00 via Android
    我也发现了这个情况。numba 和 numpy、cython 混用时耗时不降反升。推测是多种格式数据通过解释器互转效率低下。
    ipwx
        7
    ipwx  
       2018-04-04 15:47:39 +08:00
    @dwjgwsm

    In [8]: arr_list = list(np.arange(100000))

    In [10]: t1 = time.time(); moving_average(arr_list, 5); t2 = time.time(); numba_moving_average(arr_list, 5); t3 = time.time()

    In [11]: (t2 - t1, t3 - t2)
    Out[11]: (0.0019309520721435547, 0.23806500434875488)

    In [12]: t1 = time.time(); moving_average(arr_list, 5); t2 = time.time(); numba_moving_average(arr_list, 5); t3 = time.time()

    In [13]: (t2 - t1, t3 - t2)
    Out[13]: (0.0016407966613769531, 0.005582094192504883)

    In [14]: t1 = time.time(); [moving_average(arr_list, 5) for i in range(100)]; t2 = time.time(); [numba_moving_average(arr_list, 5) for i in range(100)]; t3 = time.time()

    In [15]: (t2 - t1, t3 - t2)
    Out[15]: (0.18658995628356934, 0.12822914123535156)

    In [16]: t1 = time.time(); [moving_average(arr_list, 5) for i in range(1000)]; t2 = time.time(); [numba_moving_average(arr_list, 5) for i in range(1000)]; t3 = time.time()

    In [17]: (t2 - t1, t3 - t2)
    Out[17]: (1.3983790874481201, 1.3098900318145752)
    dwjgwsm
        8
    dwjgwsm  
    OP
       2018-04-04 16:00:42 +08:00
    你这个结果也不乐观.看来还是混用不行. 后面再去折腾一下 cython 看看
    necomancer
        9
    necomancer  
       2018-04-04 16:50:22 +08:00   ❤️ 1
    我觉得 7# 说很清楚了吧,一般没有用 time.time() - start 来测试的,除非你程序大概跑在分钟级,data 大个一百万倍再说吧,timeit 是比较合适的测时间的工具。

    还有,我想吐槽这个专栏,4# (同一人哎)说得更清楚,这个专栏是来逗比的么……写个移动平均当例子把 o(n) 弄成 o(n*k),这蛋疼的 pop(0)

    更吐槽的是,还说第一反应上 NumPy,还 numpy_right ……
    为啥不用 np.convolve(data, np.ones(500)/500,mode='valid') 试试?
    20.6 ms ± 93.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 是渣渣 i7-3687u 的结果,一样的 data size (100000), 他的 cython 版本说单次时间最快也就 0.0098s 也就是 9.8ms ,这货认真的? numpy 对他来说用法仅限于 a.mean() 和方便的索引了是么……
    necomancer
        10
    necomancer  
       2018-04-04 16:53:57 +08:00
    仔细看一下连 cython 里都还有 pop(0)……这个大哥仗着自己是 i7-6700k 就日了天了么……
    DSaAAiC
        11
    DSaAAiC  
       2018-04-04 17:09:16 +08:00
    你的代码跑 10000 遍使用了 numba.jit 是 6.175 秒,不使用 numba.jit 是 72.809 秒。numba 的 jit 技术还是起到作用了。
    DSaAAiC
        12
    DSaAAiC  
       2018-04-04 17:09:53 +08:00
    @necomancer 大佬都是从哪里知道这些偏僻的 numpy 函数的,系统地看文档吗?
    necomancer
        13
    necomancer  
       2018-04-04 17:23:16 +08:00
    顺便再扯一嘴,用 convole 还是一般带窗口的,像这个方窗的情况
    ```
    def maa(data, n):
    ret = np.cumsum(data)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n-1:]/n
    ```
    渣渣本上也只要 990 微秒。这些不好好看 NumPy 的同学弄得好像 python 咋折腾都很低效……
    necomancer
        14
    necomancer  
       2018-04-04 17:29:08 +08:00   ❤️ 1
    @DSaAAiC 我不是大佬……而且这个问题不能算生僻吧,移动平均,尤其是带有窗口函数的移动平均,遇到得应该还是很多的。我其实看到“移动平均”第一反应是“这其实是个卷积的问题”,当然这么想问题也会复杂化,卷积是 o(n*k),当然一些大窗口体系还能用更快的 fftconvole ……扯远了,知道 NumPy 里都有啥好玩儿的需要一定的数学基础吧,我感觉,把遇到的问题能比较“数学地”进行描述,NumPy/SciPy 总会有惊喜。一般来说都是 Google 一下 问题+scipy 就会看到好玩儿的函数在下面贴着。
    liyuanji1002
        15
    liyuanji1002  
       2018-04-05 02:41:06 +08:00
    不知道在哪看的了, 说是 jit 启动需要花费一点时间. 可能你这段代码的计算规模还是低了点~ 试试把规模再翻几十倍看看如何.
    dwjgwsm
        16
    dwjgwsm  
    OP
       2018-04-05 10:56:22 +08:00
    @liyuanji1002 算了,运算量整太大了,脱离实际需求也没有意义了.反正优化方案里面已经 pass 掉 numba 了
    xgdgsc
        17
    xgdgsc  
       2018-04-05 11:49:43 +08:00
    nopython=True, cache=True 看看
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   5620 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 23ms · UTC 06:06 · PVG 14:06 · LAX 22:06 · JFK 01:06
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.