|
a |
|
b/code/padarray.lua |
|
|
1 |
local function dimnarrow(x,sz,pad,dim) |
|
|
2 |
local xn = x |
|
|
3 |
for i=1,x:dim() do |
|
|
4 |
if i > dim then |
|
|
5 |
xn = xn:narrow(i,pad[i]+1,sz[i]) |
|
|
6 |
end |
|
|
7 |
end |
|
|
8 |
return xn |
|
|
9 |
end |
|
|
10 |
|
|
|
11 |
local function padzero(x,pad) |
|
|
12 |
local sz = x:size() |
|
|
13 |
for i=1,x:dim() do |
|
|
14 |
sz[i] = sz[i]+pad[i]*2 |
|
|
15 |
end |
|
|
16 |
local xx = x.new(sz):zero() |
|
|
17 |
local xn = dimnarrow(xx,x:size(),pad,-1) |
|
|
18 |
xn:copy(x) |
|
|
19 |
return xx |
|
|
20 |
end |
|
|
21 |
|
|
|
22 |
local function padmirror(x,pad) |
|
|
23 |
local xx = padzero(x,pad) |
|
|
24 |
local sz = xx:size() |
|
|
25 |
for i=1,x:dim() do |
|
|
26 |
local xxn = dimnarrow(xx,x:size(),pad,i) |
|
|
27 |
for j=1,pad[i] do |
|
|
28 |
xxn:select(i,j):copy(xxn:select(i,pad[i]*2-j+1)) |
|
|
29 |
xxn:select(i,sz[i]-j+1):copy(xxn:select(i,sz[i]-pad[i]*2+j)) |
|
|
30 |
end |
|
|
31 |
end |
|
|
32 |
return xx |
|
|
33 |
end |
|
|
34 |
|
|
|
35 |
function padarray(x,pad,padtype) |
|
|
36 |
-- Example usage: img = padarray(img,{0, pady, padx},'zero') |
|
|
37 |
if x:dim() ~= #pad then |
|
|
38 |
error('number of dimensions of Input should match number of padding sizes') |
|
|
39 |
end |
|
|
40 |
if padtype == 'zero' then return padzero(x,pad) end |
|
|
41 |
if padtype == 'mirror' then return padmirror(x,pad) end |
|
|
42 |
error('unknown paddtype ' .. padtype) |
|
|
43 |
end |