为了账号安全,请及时绑定邮箱和手机立即绑定

Numpy 查找窗口数组中的最大元组

Numpy 查找窗口数组中的最大元组

噜噜哒 2023-09-26 16:16:19
我从元组列表开始(每个元组都是一个(X,Y))。我的最终结果是我想使用 numpy 找到长度为 4 的每个窗口/bin 内的最大 Y 值。# List of tuples[(0.05807200929152149, 9.9720125), (0.34843205574912894, 1.1142874), (0.6387921022067363, 2.0234027), (0.9291521486643438, 1.4435122), (1.207897793263647, 2.3677678), (1.4982578397212543, 1.9457655), (1.7886178861788617, 2.8343441), (2.078977932636469, 5.7816567), ...]# Convert to numpy arraydt = np.dtype('float,float')arr = np.asarray(listTuples, dt)# [(0.05807201, 9.97201252) (0.34843206, 1.11428738)#  (0.6387921 , 2.02340269) (0.92915215, 1.4435122 )#  (1.20789779, 2.36776781) (1.49825784, 1.9457655 )#  (1.78861789, 2.83434415) (2.07897793, 5.78165674)#  (2.36933798, 3.14842606) ...]#Create windows/blocks of 4 elementsarr = arr.reshape(-1,4)# [[(0.05807201, 9.97201252) (0.34843206, 1.11428738)#   (0.6387921 , 2.02340269) (0.92915215, 1.4435122 )]#  [(1.20789779, 2.36776781) (1.49825784, 1.9457655 )#   (1.78861789, 2.83434415) (2.07897793, 5.78165674)]#  [(2.36933798, 3.14842606) (2.95005807, 2.10357308)#   (3.24041812, 1.15985966) (3.51916376, 2.03056955)]...]print(arr.max(axis=1)) <-- ERROR HEREprint(max(arr,key=lambda x:x[1])) <-- ERROR, tried this too but doesn't work我想要使用最大 y 值从每个窗口/块获得的预期输出如下。或者,格式可以是常规元组列表,并不严格需要是 numpy 数组:[[(0.05807201, 9.97201252)][(2.07897793, 5.78165674)][(2.36933798, 3.14842606)]...]OR other format:[(0.05807201, 9.97201252),(2.07897793, 5.78165674),(2.36933798, 3.14842606)]...]
查看完整描述

3 回答

?
牧羊人nacy

TA贡献1862条经验 获得超7个赞

这是一种矢量化(可能是最快的)方法:


arr = np.asarray(listTuples, np.dtype('float,float'))

idx = arr.view(('float',2))[:,1].reshape(-1,4).argmax(1)

arr = arr.reshape(-1,4)[np.arange(len(idx)),idx]

#[(0.05807201, 9.9720125) (2.07897793, 5.7816567) ...]

您基本上使用结构化数组的数组(非结构化)版本,并view沿 查找 Y 的 argmax axis=1。然后使用这些索引来过滤原始数组中的元组arr。


查看完整回答
反对 回复 2023-09-26
?
HUX布斯

TA贡献1876条经验 获得超6个赞

这应该可以解决你的问题。


输入:元组列表


输出:元组列表,取每个 4 个元素块中 y 值最大的元组 import numpy as np


# List of tuples

listTuples = [(1,1),(120,1000),(12,90),(1,1),(0.05807200929152149, 9.9720125), 

(0.34843205574912894, 1.1142874), (0.6387921022067363, 2.0234027), 

(0.9291521486643438, 1.4435122), (1.207897793263647, 2.3677678), 

(1.4982578397212543, 1.9457655), (1.7886178861788617, 2.8343441), 

(2.078977932636469, 5.7816567)]



def extractMaxY(li):

    result = []

    index = 0

    for i in range(0,len(li), 4):

        max = -100000

#find the max Y in blocks of 4

        for j in range(4):

            if li[i+j][1] > max:

                max = li[i+j][1]

                index = i+j

        result.append(li[index])

    return result



print(extractMaxY(listTuples))

然后输出是


[(120, 1000), (0.05807200929152149, 9.9720125), (2.078977932636469, 

5.7816567)]

应该如此,对吗?


查看完整回答
反对 回复 2023-09-26
?
有只小跳蛙

TA贡献1824条经验 获得超8个赞

你可以试试这个


import numpy as np

ans = []

for value, comp in zip(x, np.max(x, axis=1)[:,1]):

    ans.append([(i, j) for i, j in value if np.isclose(j,comp)])

print(np.array(ans))

将其应用于您的数据


x =  np.array([[(0.05807201, 9.97201252), (0.34843206, 1.11428738),

  (0.6387921 , 2.02340269), (0.92915215, 1.4435122 )],

 [(1.20789779, 2.36776781), (1.49825784, 1.9457655 ),

  (1.78861789, 2.83434415), (2.07897793, 5.78165674)],

 [(2.36933798, 3.14842606), (2.95005807, 2.10357308),

  (3.24041812, 1.15985966), (3.51916376, 2.03056955)]])

退货


[[[ 0.05807201  9.97201252]]


 [[ 2.07897793  5.78165674]]


 [[ 2.36933798  3.14842606]]]



查看完整回答
反对 回复 2023-09-26
  • 3 回答
  • 0 关注
  • 76 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信