为了账号安全,请及时绑定邮箱和手机立即绑定

Numba 和多维添加 - 不适用于 numpy.newaxis?

Numba 和多维添加 - 不适用于 numpy.newaxis?

慕斯王 2022-10-18 16:32:36
尝试在 python 上加速 DP 算法,numba 似乎是一个合适的候选者。我正在用提供 3D 数组的 1D 数组减去 2D 数组。然后我使用.argmin()第三维来获得一个二维数组。这适用于 numpy,但不适用于 numba。重现问题的玩具代码:from numba import jitimport numpy as npinflow      = np.arange(1,0,-0.01)                  # Dim [T]actions     = np.arange(0,1,0.05)                   # Dim [M]start_lvl   = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]disc_lvl    = np.arange(0,1000)                     # Dim [O]@jit(nopython=True)def my_func(disc_lvl, actions, start_lvl, inflow):    for i in range(0,100):        # Calculate new level at time i        new_lvl = start_lvl + inflow[i] + actions       # Dim [N x M]        # For each new_level element, find closest discretized level        diff    = (disc_lvl-new_lvl[:,:,np.newaxis])    # Dim [N x M x O]        idx_lvl = abs(diff).argmin(axis=2)              # Dim [N x M]        return True# function works fine without numbasuccess = my_func(disc_lvl, actions, start_lvl, inflow)为什么上面的代码不运行?取出时会这样@jit(nopython=True)。是否有一个工作回合可以使以下计算与 numba 一起工作?我尝试了带有 numpy repeats 和 expand_dims 的变体,以及明确定义 jit 函数的输入类型但没有成功。
查看完整描述

2 回答

?
HUX布斯

TA贡献1876条经验 获得超6个赞

您需要进行一些更改才能使其正常工作:

  1. 使用 : 为 Numba 添加维度arr[:, :, None],看起来getitem更喜欢使用reshape

  2. 使用np.abs而不是内置abs

  3. argminwithaxis关键字参数未实现。更喜欢使用 Numba 旨在优化的循环。

修复所有这些后,您可以运行 jited 函数:

from numba import jit

import numpy as np


inflow = np.arange(1,0,-0.01)  # Dim [T]

actions = np.arange(0,1,0.05)  # Dim [M]

start_lvl = np.random.rand(500).reshape(-1,1)*49  # Dim [Nx1]

disc_lvl = np.arange(0,1000)  # Dim [O]


@jit(nopython=True)

def my_func(disc_lvl, actions, start_lvl, inflow):

    for i in range(0,100):

        # Calculate new level at time i

        new_lvl = start_lvl + inflow[i] + actions  # Dim [N x M]


        # For each new_level element, find closest discretized level

        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)

        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]


        idx_lvl = np.empty(new_lvl.shape)

        for i in range(diff.shape[0]):

            for j in range(diff.shape[1]):

                idx_lvl[i, j] = diff[i, j, :].argmin()


        return True


# function works fine without numba

success = my_func(disc_lvl, actions, start_lvl, inflow)


查看完整回答
反对 回复 2022-10-18
?
翻过高山走不出你

TA贡献1875条经验 获得超3个赞

在我的第一篇文章的更正代码下方找到,您可以在使用和不使用 numba 库的 jitted 模式的情况下执行(通过删除以 @jit 开头的行)。我观察到这个例子的速度增加了 2 倍。


from numba import jit

import numpy as np

import datetime as dt


inflow = np.arange(1,0,-0.01)                       # Dim [T]

nbTime = np.shape(inflow)[0]

actions = np.arange(0,1,0.01)                       # Dim [M]

start_lvl = np.random.rand(500).reshape(-1,1)*49    # Dim [Nx1]

disc_lvl = np.arange(0,1000)                        # Dim [O]


@jit(nopython=True)

def my_func(nbTime, disc_lvl, actions, start_lvl, inflow):

    # Initialize result 

    res = np.empty((nbTime,np.shape(start_lvl)[0],np.shape(actions)[0]))


    for t in range(0,nbTime):

        # Calculate new level at time t

        new_lvl = start_lvl + inflow[t] + actions  # Dim [N x M]      

        print(t)


        # For each new_level element, find closest discretized level

        new_lvl_3d = new_lvl.reshape(*new_lvl.shape, 1)

        diff = np.abs(disc_lvl - new_lvl_3d)  # Dim [N x M x O]


        idx_lvl = np.empty(new_lvl.shape)

        for i in range(diff.shape[0]):

            for j in range(diff.shape[1]):

                idx_lvl[i, j] = diff[i, j, :].argmin()


        res[t,:,:] = idx_lvl


    return res


# Call function and print running time

start_time = dt.datetime.now()

result = my_func(nbTime, disc_lvl, actions, start_lvl, inflow)

print('Execution time :',(dt.datetime.now() - start_time))


查看完整回答
反对 回复 2022-10-18
  • 2 回答
  • 0 关注
  • 93 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信