ブロードキャスティング
(皆さんにとっては、『そんなの当たり前!』と思っているかも知れませんが
私自身すごく忘れっぽいものでから・・・) (苦笑)
W = np.array([[0, 1, 2], [3, 4, 5]])
x = np.array([6, 7, 8])
1 print(np.ones((2,3))*x) # x:(3,) -> (1,3) -> (2,3)
[[6. 7. 8.]
[6. 7. 8.]]
2 print(np.ones((2,3))*x[np.newaxis, :]) # x:(1,3) -> (2,3)
[[6. 7. 8.]
[6. 7. 8.]]
3 print(np.ones((2,3))*x[:, np.newaxis]) # x:(3,1) -> (2,3) unable
shapeで考えます。 (2, 3) * (3, 1) ⇒ 出力次元 (2, 3) で (3, 1) の1次元の1が
出力の3と不一致でブロードキャスティングできない
4 print(np.ones((3,2))*x) # x:(3,) -> (1,3) -> (3,2) unable
(3, 2) * (1, 3) ⇒ (3, 2) の出力の1次元2と (1, 3) の3が不一致で
ブロードキャスティングできない
5 print(np.ones((3,2))*x[np.newaxis, :]) # x:(1,3) -> (3,2) unable
上と同じ理由でブロードキャスティングできない
6 print(np.ones((3,2))*x[:, np.newaxis]) # x:(3,1) -> (3,2)
[[6. 6.]
[7. 7.]
[8. 8.]]
7 print(np.ones((3,3)) * x)
[[6. 7. 8.]
[6. 7. 8.]
[6. 7. 8.]]
8 print(x[:, np.newaxis] * np.ones((3,3)))
[[6. 6. 6.]
[7. 7. 7.]
[8. 8. 8.]]
9 print(x[:, np.newaxis] * x)
[[36 42 48]
[42 49 56]
[48 56 64]]
10 print(W.T + x[np.newaxis, :].T)
[[ 6 9]
[ 8 11]
[10 13]]
この記事へのコメントはありません。