我在 numpy 的帮助下编写了以下代码,我想用 numba 提高性能。我不确定为什么它不起作用,因为我已经按照 numba 系统设置了所有变量。我正在尝试加快此代码的速度,因为我将来会使用大型数据集。import numpy as npimport mathfrom numba import jitclass from numba import float64,int64spec =[ ('spacing',float64), ('n_iterations',int64), ('np_emptyhouses',float64[:,:]), ('np_agenthouses',float64[:,:]), ('similarity_threshhold',float64), ('n_changes',int64) ]@jitclass(spec)class geo_schelling_update: def __init__(self,n_iterations,spacing,np_agenthouses,np_emptyhouses,similarity_threshhold): self.spacing=spacing self.n_iterations=n_iterations self.np_emptyhouses=np_emptyhouses self.np_agenthouses=np_agenthouses self.similarity_threshhold=similarity_threshhold def distance_vectorize(self,pointA1, pointA2,agent): x_square=np.square(pointA1-agent[0]) y_square=np.square(pointA2-agent[1]) dist=np.sqrt(np.array(x_square,dtype=np.float32)+np.array(y_square,dtype=np.float32)) return np.round(dist,4) def is_unsatisfied_vectorize(self,x,y): race = np.extract(np.logical_and(np.equal(self.np_agenthouses[:,0],x),np.equal(self.np_agenthouses[:,1],y)),self.np_agenthouses[:,2])[0] euclid_distance1=round(math.hypot(self.spacing,self.spacing),4) euclid_distance2=self.spacing total_agents=np.extract(np.logical_or(np.equal(np.round(np.hypot((self.np_agenthouses[:,0]-(x)),(self.np_agenthouses[:,1]-(y))),4),euclid_distance1),np.equal(np.round(np.hypot((self.np_agenthouses[:,0]-(x)),(self.np_agenthouses[:,1]-(y))),4),euclid_distance2)),self.np_agenthouses[:,2]) if total_agents.size ==0: return False else: return total_agents[total_agents==race].size/total_agents.size<self.similarity_threshhold
1 回答

撒科打诨
TA贡献1934条经验 获得超2个赞
问题在于np.round
. 从文档中并不完全清楚,但您可以通过查看source看到,如果您在数组输入上使用该函数,则需要提供所有 3 个参数。所以以下不起作用:
nb.jit(nopython=True)def func(x): return np.round(x)
但以下工作按预期工作:
nb.jit(nopython=True)def func(x): out = np.empty_like(x) np.round(x, 0, out) return out
有关完整说明,请参阅文档。np.around
我将在 numba 问题跟踪器上提出一个问题,因为这在查看文档时并不明显。
添加回答
举报
0/150
提交
取消