为了账号安全,请及时绑定邮箱和手机立即绑定

为什么 torch.lstsq 的输出与 np.linalg.lstsq 截然不同?

为什么 torch.lstsq 的输出与 np.linalg.lstsq 截然不同?

ABOUTYOU 2023-06-13 11:14:57
Pytorch 提供了一个lstsq函数,但是它返回的结果与 numpy 的版本有很大的不同。这是一个示例输入及其结果:import numpy as npimport torch a = torch.tensor([[1., 1, 1],                  [2, 3, 4],                  [3, 5, 2],                  [4, 2, 5],                  [5, 4, 3]])b = torch.tensor([[-10., -3],                  [ 12, 14],                  [ 14, 12],                  [ 16, 16],                  [ 18, 16]])a1 = a.clone().numpy()b1 = b.clone().numpy()x, r = torch.lstsq(b, a)x1, res, r1, s = np.linalg.lstsq(b1, a1)print(f'torch_x: {x}')print(f'torch_r: {r}\n')print(f'np_x: {x1}')print(f'np_res: {res}')print(f'np_r1(rank): {r1}')print(f'np_s: {s}')输出:torch_x: tensor([[ 2.0000,  1.0000],        [ 1.0000,  1.0000],        [ 1.0000,  2.0000],        [10.9635,  4.8501],        [ 8.9332,  5.2418]])torch_r: tensor([[-7.4162, -6.7420, -6.7420],        [ 0.2376, -3.0896,  0.1471],        [ 0.3565,  0.5272,  3.0861],        [ 0.4753, -0.3952, -0.4312],        [ 0.5941, -0.1411,  0.2681]])np_x: [[-0.11452514 -0.10474861 -0.28631285] [ 0.35913807  0.33719075  0.54070234]]np_res: [ 5.4269753 10.197526   1.4185953]np_r1(rank): 2np_s: [43.057705  5.199417]我在这里错过了什么?
查看完整描述

1 回答

?
九州编程

TA贡献1785条经验 获得超4个赞

torch.lstq(a, b)求解minX L2∥bX−a∥ 同时 np.linalg.lstsq(a, b)求解minX L2∥aX−b∥


所以改变传递参数的顺序。


这是一个示例:


将 numpy 导入为 np 导入火炬


a = torch.tensor([[1., 1, 1],

                  [2, 3, 4],

                  [3, 5, 2],

                  [4, 2, 5],

                  [5, 4, 3]])


b = torch.tensor([[-10., -3],

                  [ 12, 14],

                  [ 14, 12],

                  [ 16, 16],

                  [ 18, 16]])


a1 = a.clone().numpy()

b1 = b.clone().numpy()


x, _ = torch.lstsq(a, b)


x1, res, r1, s = np.linalg.lstsq(b1, a1)


print(f'torch_x: {x[:b.shape[1]]}')


print(f'np_x: {x1}')

结果:


torch_x: tensor([[-0.1145, -0.1047, -0.2863],

        [ 0.3591,  0.3372,  0.5407]])

np_x: [[-0.11452514 -0.10474861 -0.28631285]

 [ 0.35913807  0.33719075  0.54070234]]

而且rank从 numpy.lianalg.lstsq 返回的是第一个参数的等级。要在 pytorch 使用函数中获得排名torch.matrix_rank()


查看完整回答
反对 回复 2023-06-13
  • 1 回答
  • 0 关注
  • 125 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信