Numba加速python代码
发布日期:2021-06-30 20:33:02 浏览次数:2 分类:技术文章

本文共 2218 字,大约阅读时间需要 7 分钟。

Numba 读取装饰函数的 Python 字节码,并将其与有关函数输入参数类型的信息结合起来,分析和优化代码,最后使用编译器库(LLVM)针对你的 CPU 生成量身定制的机器代码。每次调用函数时,都会使用此编译版本,来达到加速的目的。

1 原始代码

import mathimport timedef is_prime(num):    if num == 2:        return True    if num <= 1 or not num % 2:        return False    for div in range(3, int(math.sqrt(num) + 1), 2):        if not num % div:            return False    return Truedef run_program(N):    total = 0    for i in range(N):        if is_prime(i):            total += 1    return totalif __name__ == "__main__":    N = 10000000    start = time.time()    total = run_program(N)    end = time.time()    print(f"total prime num is {total}")    print(f"cost {end - start}s")

运行结果:

total prime num is 664579    cost 53.910285234451294s

2 导入 Numba 的 njit,再在函数上方放个装饰器 @njit

import mathimport timefrom numba import njit# @njit 相当于 @jit(nopython=True)@njitdef is_prime(num):    if num == 2:        return True    if num <= 1 or not num % 2:        return False    for div in range(3, int(math.sqrt(num) + 1), 2):        if not num % div:            return False    return True# @njit 相当于 @jit(nopython=True)@njit(parallel = True)def run_program(N):    total = 0    for i in range(N):        if is_prime(i):            total += 1    return totalif __name__ == "__main__":    N = 10000000    start = time.time()    total = run_program(N)    end = time.time()    print(f"total prime num is {total}")    print(f"cost {end - start}s")

运行结果:

total prime num is 664579    cost 3.5231616497039795s

3 加入prange 参数来并行计算

import mathimport timefrom numba import njit, prange# @njit 相当于 @jit(nopython=True)@njitdef is_prime(num):    if num == 2:        return True    if num <= 1 or not num % 2:        return False    for div in range(3, int(math.sqrt(num) + 1), 2):        if not num % div:            return False    return True# @njit 相当于 @jit(nopython=True)@njit(parallel = True)def run_program(N):    total = 0    for i in prange(N):        if is_prime(i):            total += 1    return totalif __name__ == "__main__":    N = 10000000    start = time.time()    total = run_program(N)    end = time.time()    print(f"total prime num is {total}")    print(f"cost {end - start}s")

运行结果:

total prime num is 664579    cost 1.087864875793457s

转载地址:https://liumin.blog.csdn.net/article/details/115893084 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:PyTorch代码优化技巧
下一篇:修改yolov5的输入图像尺寸为指定尺寸

发表评论

最新留言

能坚持,总会有不一样的收获!
[***.219.124.196]2024年04月15日 10时00分07秒