Diff of /code/padarray.lua [000000] .. [b758a2]

Switch to unified view

a b/code/padarray.lua
1
local function dimnarrow(x,sz,pad,dim)
2
    local xn = x
3
    for i=1,x:dim() do
4
        if i > dim then
5
            xn = xn:narrow(i,pad[i]+1,sz[i])
6
        end
7
    end
8
    return xn
9
end
10
11
local function padzero(x,pad)
12
    local sz = x:size()
13
    for i=1,x:dim() do 
14
      sz[i] = sz[i]+pad[i]*2 
15
    end
16
    local xx = x.new(sz):zero()
17
    local xn = dimnarrow(xx,x:size(),pad,-1)
18
    xn:copy(x)
19
    return xx
20
end
21
22
local function padmirror(x,pad)
23
    local xx = padzero(x,pad)
24
    local sz  = xx:size()
25
    for i=1,x:dim() do
26
        local xxn = dimnarrow(xx,x:size(),pad,i)
27
        for j=1,pad[i] do
28
            xxn:select(i,j):copy(xxn:select(i,pad[i]*2-j+1))
29
            xxn:select(i,sz[i]-j+1):copy(xxn:select(i,sz[i]-pad[i]*2+j))
30
        end
31
    end
32
    return xx
33
end
34
35
function padarray(x,pad,padtype)
36
-- Example usage:  img = padarray(img,{0, pady, padx},'zero')
37
    if x:dim() ~= #pad then
38
        error('number of dimensions of Input should match number of padding sizes')
39
    end
40
    if padtype == 'zero' then return padzero(x,pad) end
41
    if padtype == 'mirror' then return padmirror(x,pad) end
42
    error('unknown paddtype ' .. padtype)
43
end