[304dd3]: / lib / normalize.py

Download this file

14 lines (11 with data), 350 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
import torch
from torch.autograd import Variable
from torch import nn
class Normalize(nn.Module):
    """Scale each slice of the input along dimension 1 to unit Lp norm.

    For an input ``x`` the output is ``x / max(||x||_p, eps)``, where the
    norm is taken over dimension 1 and broadcast back over that dimension.

    Args:
        power: Exponent ``p`` of the norm (2 gives the Euclidean norm).
        eps: Lower bound applied to the norm before dividing, so that an
            all-zero slice maps to zeros instead of NaN.
    """

    def __init__(self, power=2, eps=1e-12):
        super(Normalize, self).__init__()
        self.power = power
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its Lp norm along dimension 1.

        Args:
            x: Tensor with at least two dimensions; the reduction runs
                over dimension 1.

        Returns:
            A tensor of the same shape as ``x``.
        """
        # abs() makes this a true Lp norm for odd exponents: without it,
        # negative entries cancel in the sum and a negative total raised to
        # a fractional power yields NaN. For even exponents it is a no-op.
        norm = x.abs().pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        # Clamp so that a zero norm does not produce 0/0 = NaN.
        return x.div(norm.clamp(min=self.eps))