Numo::NArrayで配列のパディングしたい

最近Red ChainerでCNNを実現できるようにあれこれ実装をしてるのですが、その過程でパディングをNumo::Arrayで実現する必要がありました。
パディングは畳み込みの処理をする前に入力データの周囲の固定のデータ(例えば0)を埋めることです。

f:id:hatappi1225:20180317120804p:plain

これをNumo::NArrayの配列で実現したいのが今回の目的です。
今回は4次元の配列を例にしてます。それぞれの次元では次の役割を担っています。
(バッチ数, チャネル数, 画像の高さ, 画像の幅)

numpyだとどう実現できるのか

ちょっと道を外れてNumo::NArrayがインスパイアされたnumpyではどう実現できるのかを見ていきます。
numpyにはpadと呼ばれるメソッドが用意されておりこれで実現することができます。

import numpy as np

x = np.arange(48, dtype='f').reshape(1, 3, 4, 4)
print("==== x ====")
print(x)

pad_x = np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0)
print("\n==== pad_x ====")
print(pad_x)
$ python pad.py
==== x ====
[[[[ 0.  1.  2.  3.]
   [ 4.  5.  6.  7.]
   [ 8.  9. 10. 11.]
   [12. 13. 14. 15.]]

  [[16. 17. 18. 19.]
   [20. 21. 22. 23.]
   [24. 25. 26. 27.]
   [28. 29. 30. 31.]]

  [[32. 33. 34. 35.]
   [36. 37. 38. 39.]
   [40. 41. 42. 43.]
   [44. 45. 46. 47.]]]]

==== pad_x ====
[[[[ 0.  0.  0.  0.  0.  0.]
   [ 0.  0.  1.  2.  3.  0.]
   [ 0.  4.  5.  6.  7.  0.]
   [ 0.  8.  9. 10. 11.  0.]
   [ 0. 12. 13. 14. 15.  0.]
   [ 0.  0.  0.  0.  0.  0.]]

  [[ 0.  0.  0.  0.  0.  0.]
   [ 0. 16. 17. 18. 19.  0.]
   [ 0. 20. 21. 22. 23.  0.]
   [ 0. 24. 25. 26. 27.  0.]
   [ 0. 28. 29. 30. 31.  0.]
   [ 0.  0.  0.  0.  0.  0.]]

  [[ 0.  0.  0.  0.  0.  0.]
   [ 0. 32. 33. 34. 35.  0.]
   [ 0. 36. 37. 38. 39.  0.]
   [ 0. 40. 41. 42. 43.  0.]
   [ 0. 44. 45. 46. 47.  0.]
   [ 0.  0.  0.  0.  0.  0.]]]]

Numo::NArrayではどうするか?

numpyとの対応表をみてみるとpadに対応したものはまだ無いようです:cry:

Numo vs numpy · ruby-numo/numo-narray Wiki · GitHub

どう実現するか

さすがに各要素をeachしていくわけにはいかない。。。

docを見ると次のような記述が!!

(その次元の)すべての要素の参照(つまり0..-1)は true で代替できます。 高度な使い方として、残りの次元に対してすべて true を与えたいが次元の個数が不定のとき、 false を与えることもできます。

今回の例でいうとバッチサイズ、チャネルの部分ではこれが使えそうです。
最終的に次のようなコードになりました。

require 'numo/narray'

x = Numo::DFloat.new(1, 3, 4, 4).seq
p x

pad_x = Numo::DFloat.zeros(2, 3, 6, 6)
pad_x[nil, nil, 1..4, 1..4] = x[nil, nil, 0..3, 0..3]

p pad_x
$ ruby pad.rb
Numo::DFloat#shape=[1,3,4,4]
[[[[0, 1, 2, 3], 
   [4, 5, 6, 7], 
   [8, 9, 10, 11], 
   [12, 13, 14, 15]], 
  [[16, 17, 18, 19], 
   [20, 21, 22, 23], 
   [24, 25, 26, 27], 
   [28, 29, 30, 31]], 
  [[32, 33, 34, 35], 
   [36, 37, 38, 39], 
   [40, 41, 42, 43], 
   [44, 45, 46, 47]]]]
Numo::DFloat#shape=[2,3,6,6]
[[[[0, 0, 0, 0, 0, 0], 
   [0, 0, 1, 2, 3, 0], 
   [0, 4, 5, 6, 7, 0], 
   [0, 8, 9, 10, 11, 0], 
   [0, 12, 13, 14, 15, 0], 
   [0, 0, 0, 0, 0, 0]], 
  [[0, 0, 0, 0, 0, 0], 
   [0, 16, 17, 18, 19, 0], 
   [0, 20, 21, 22, 23, 0], 
   [0, 24, 25, 26, 27, 0], 
   [0, 28, 29, 30, 31, 0], 
   [0, 0, 0, 0, 0, 0]], 
  [[0, 0, 0, 0, 0, 0], 
   [0, 32, 33, 34, 35, 0], 
   [0, 36, 37, 38, 39, 0], 
   [0, 40, 41, 42, 43, 0], 
   [0, 44, 45, 46, 47, 0], 
   [0, 0, 0, 0, 0, 0]]], 
 [[[0, 0, 0, 0, 0, 0], 
   [0, 0, 1, 2, 3, 0], 
 ...

ボツ案

最初は愚直にバッチとチャネル部分でループさせてました。

require 'numo/narray'

x = Numo::DFloat.new(2, 3, 4, 4).seq
r = Numo::DFloat.zeros(2, 3, 6, 6)

x_shape = x.shape
x_shape[0].times do |b|
  x_shape[1].times do |c|
    r[b, c, 1..4, 1..4] = x[b, c, 0..3, 0..3]
  end
end

これは見た目もですが、ループさせている分のパフォーマンスも悪いです。
ベンチマークとった結果を次にのせます。

ベンチマークコード

require 'numo/narray'
require 'benchmark'

TRY_CNT=50000

Benchmark.bm(10) do |b|
  b.report('reject') do
    TRY_CNT.times do 
      x = Numo::DFloat.new(2, 3, 4, 4).seq
      r = Numo::DFloat.zeros(2, 3, 6, 6)
      x_shape = x.shape
      x_shape[0].times do |b|
        x_shape[1].times do |c|
          r[b, c, 1..4, 1..4] = x[b, c, 0..3, 0..3]
        end
      end 
    end
  end

  b.report('accept') do
    TRY_CNT.times do 
      x = Numo::DFloat.new(2, 3, 4, 4).seq
      r = Numo::DFloat.zeros(2, 3, 6, 6)
      r[nil, nil, 1..4, 1..4] = x[nil, nil, 0..3, 0..3]
    end
  end
end

結果

> ruby bench.rb
                 user     system      total        real
reject       1.695526   0.004662   1.700188 (  1.722108)
accept       0.507013   0.001961   0.508974 (  0.522136)

まとめ

パディングできた!!