为了账号安全,请及时绑定邮箱和手机立即绑定

如何获取 Spark DataFrame 中每行列表中最大值的索引?

如何获取 Spark DataFrame 中每行列表中最大值的索引?

开心每一天1111 2022-07-12 16:27:42
我已经完成了 LDA 主题建模并将其存储在lda_model.转换我的原始输入数据集后,我检索了一个 DataFrame。其中一列是 topicDistribution,其中该行属于 LDA 模型中每个主题的概率。因此,我想获取每行列表中最大值的索引。df -- | 'list_of_words' | 'index ' | 'topicDistribution' |        ['product','...']     0       [0.08,0.2,0.4,0.0001]          .....             ...         ........我想转换 df 以便添加一个附加列,它是每行 topicDistribution 列表的 argmax。df_transformed --  | 'list_of_words' | 'index' | 'topicDistribution' | 'topicID' |                    ['product','...']     0     [0.08,0.2,0.4,0.0001]      2                       ......            ....         .....              ....我该怎么做?
查看完整描述

1 回答

?
qq_花开花谢_0

TA贡献1835条经验 获得超7个赞

您可以创建一个用户定义的函数来获取最大值的索引


from pyspark.sql import functions as f

from pyspark.sql.types import IntegerType


max_index = f.udf(lambda x: x.index(max(x)), IntegerType())

df = df.withColumn("topicID", max_index("topicDistribution"))

例子


>>> from pyspark.sql import functions as f

>>> from pyspark.sql.types import IntegerType 

>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])

>>> df.show()

+-----------------+

|topicDistribution|

+-----------------+

|  [0.2, 0.3, 0.5]|

+-----------------+


>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())

>>> df.withColumn("topicID", max_index("topicDistribution")).show()

+-----------------+-------+

|topicDistribution|topicID|

+-----------------+-------+

|  [0.2, 0.3, 0.5]|      2|

+-----------------+-------+

编辑:


由于您提到其中的列表topicDistribution是 numpy 数组,因此您可以更新max_index udf如下:


max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())


查看完整回答
反对 回复 2022-07-12
  • 1 回答
  • 0 关注
  • 351 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号