
--- a
+++ b/dl/affinitynet/graph_attention.py
@@ -0,0 +1,1276 @@
+import functools
+import collections.abc
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+from ..models.transformer import *
+
+if torch.cuda.is_available():
+    dtype = {'float': torch.cuda.FloatTensor, 'long': torch.cuda.LongTensor, 'byte': torch.cuda.ByteTensor} 
+else:
+    dtype = {'float': torch.FloatTensor, 'long': torch.LongTensor, 'byte': torch.ByteTensor} 
+
+
+def get_iterator(x, n, forced=False):
+    r"""If x is int, copy it to a list of length n
+    Cannot handle a special case when the input is an iterable and len(x) = n, 
+    but we still need to copy it to a list of length n
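+
+    A minimal usage sketch (plain-Python behavior, expected results shown as comments):
+
+    >>> get_iterator(3, 4)           # -> [3, 3, 3, 3]
+    >>> get_iterator([1, 2], 3)      # length mismatch, so the whole list is replicated: -> [[1, 2], [1, 2], [1, 2]]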
+    """
+    if forced:
+        return [x] * n
+    if not isinstance(x, collections.abc.Iterable) or isinstance(x, str):
+        x = [x] * n
+    # Note: np.array, list are always iterable
+    if len(x) != n:
+        x = [x] * n
+    return x
+
+def get_partial_model(model_part, model):
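+    r"""Copy into model_part the parameters of model whose names also appear in model_part,
+    e.g. to initialize a sub-model from a pretrained full model.
+    """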
+    pretrained_state_dict = {k: v for k, v in model.state_dict().items() if k in model_part.state_dict()}
+    state_dict = model_part.state_dict()
+    state_dict.update(pretrained_state_dict)
+    model_part.load_state_dict(state_dict)
+
+    
+class DenseLinear(nn.Module):
+    r"""Multiple linear layers densely connected
+    
+    Args:
+        in_dim: int, number of features
+        hidden_dim: iterable of int
+        nonlinearity: default nn.ReLU()
+                        can be changed to other nonlinear activations
+        last_nonlinearity: if True, apply nonlinearity to the last output; default False
+        dense: if dense, concatenate all previous intermediate features to current input
+        forward_input: whether the original input is concatenated to the current input; only used when dense is True
+                        if return_all is True, return_layers is None, and forward_input is True, 
+                            then concatenate the input with all hidden outputs as the final output
+        return_all: if True return all layers
+        return_layers: selected layers to output; used only when return_all is True
+        bias: if True, use bias in nn.Linear()
+        
+    Shape:
+    
+    Attributes:
+        A series of weight and bias parameters, one pair per linear layer
+    
+    Examples:
+    
+    >>> m = DenseLinear(3, [3,4], return_all=True)
+    >>> x = Variable(torch.randn(4,3))
+    >>> m(x)
+    """
+    def __init__(self, in_dim, hidden_dim, nonlinearity=nn.ReLU(), last_nonlinearity=False, dense=True,
+                forward_input=False, return_all=False, return_layers=None, bias=True):
+        super(DenseLinear, self).__init__()
+        num_layers = len(hidden_dim)
+        nonlinearity = get_iterator(nonlinearity, num_layers)
+        bias = get_iterator(bias, num_layers)
+        self.forward_input = forward_input
+        self.return_all = return_all
+        self.return_layers = return_layers
+        self.dense = dense
+        self.last_nonlinearity = last_nonlinearity
+        
+        self.layers = nn.Sequential()
+        cnt_dim = in_dim if forward_input else 0
+        for i, h in enumerate(hidden_dim):
+            self.layers.add_module('linear'+str(i), nn.Linear(in_dim, h, bias[i]))
+            if i < num_layers-1 or last_nonlinearity:
+                self.layers.add_module('activation'+str(i), nonlinearity[i])
+            cnt_dim += h
+            in_dim = cnt_dim if dense else h
+            
+    def forward(self, x):
+        if self.forward_input:
+            y = [x]
+        else:
+            y = []
+        out = x
+        for n, m in self.layers._modules.items():
+            out = m(out)
+            if n.startswith('activation'):
+                y.append(out)
+                if self.dense:
+                    out = torch.cat(y, dim=-1)
+        if self.return_all:
+            if not self.last_nonlinearity: # add last output even if there is no nonlinearity
+                y.append(out)
+            if self.return_layers is not None:
+                return_layers = [i%len(y) for i in self.return_layers]
+                y = [h for i, h in enumerate(y) if i in return_layers]
+            return torch.cat(y, dim=-1)
+        else:
+            return out
+        
+    
+class FineTuneModel(nn.Module):
+    r"""Finetune the last layer(s) (usually newly added) with a pretained model to learn a representation
+    
+    Args:
+        pretrained_model: nn.Module, the pretrained model
+        new_layer: nn.Module, newly added layer
+        freeze_pretrained: if True, set requires_grad=False for pretrained_model parameters
+        
+    Shape:
+        - Input: (N, *)
+        - Output: 
+        
+    Attributes:
+        All model parameters of pretrained_model and new_layer
+    
+    Examples:
+    
+        >>> m = nn.Linear(2,3)
+        >>> model = FineTuneModel(m, nn.Linear(3,2))
+        >>> x = Variable(torch.ones(1,2))
+        >>> print(m(x))
+        >>> print(model(x))
+        >>> print(FeatureExtractor(model, [0,1])(x))
+    """
+    def __init__(self, pretrained_model, new_layer, freeze_pretrained=True):
+        super(FineTuneModel, self).__init__()
+        self.pretrained_model = pretrained_model
+        self.new_layer = new_layer
+        if freeze_pretrained:
+            for p in self.pretrained_model.parameters():
+                p.requires_grad = False
+                
+    def forward(self, x):
+        return self.new_layer(self.pretrained_model(x))
+    
+    
+class FeatureExtractor(nn.Module):
+    r"""Extract features from different layers of the model
+    
+    Args:
+        model: nn.Module, the model
+        selected_layers: an iterable of int or 'string' (as module name), selected layers
+        
+    Shape:
+        - Input: (N,*)
+        - Output: a list of Variables, depending on model and selected_layers
+        
+    Attributes: 
+        None learnable
+       
+    Examples:
+    
+        >>> m = nn.Sequential(nn.Linear(2,2), nn.Linear(2,3))
+        >>> m = FeatureExtractor(m, [0,1])
+        >>> x = Variable(torch.randn(1, 2))
+        >>> m(x)
+    """
+    def __init__(self, model, selected_layers=None, return_list=False):
+        super(FeatureExtractor, self).__init__()
+        self.model = model
+        self.selected_layers = selected_layers
+        if self.selected_layers is None:
+            self.selected_layers = range(len(model._modules))
+        self.return_list = return_list
+    
+    def set_selected_layers(self, selected_layers):
+        self.selected_layers = selected_layers
+        
+    def forward(self, x):
+        out = []
+        for i, (name, m) in enumerate(self.model._modules.items()):
+            x = m(x)
+            if i in self.selected_layers or name in self.selected_layers:
+                out.append(x)
+        if self.return_list:
+            return out
+        else:
+            return torch.cat(out, dim=-1)
+    
+
+class WeightedFeature(nn.Module):
+    r"""Transform features into weighted features
+    
+    Args:
+        num_features: int
+        reduce: if True, return weighted mean
+        
+    Shape: 
+        - Input: (N, *, num_features) where * means any number of dimensions
+        - Output: (N, *, num_features) if reduce is False (default) else (N, *)
+        
+    Attributes:
+        weight: (num_features)
+        
+    Examples::
+    
+        >>> m = WeightedFeature(10)
+        >>> x = torch.autograd.Variable(torch.randn(5,10))
+        >>> out = m(x)
+        >>> print(out)
+    """
+    def __init__(self, num_features, reduce=False, magnitude=None):
+        super(WeightedFeature, self).__init__()
+        self.reduce = reduce
+        self.weight = nn.Parameter(torch.empty(num_features))
+        # initialize with uniform weight
+        self.weight.data.fill_(1)
+        self.magnitude = 1 if magnitude is None else magnitude
+    
+    def forward(self, x):
+        self.normalized_weight = torch.nn.functional.softmax(self.weight, dim=0)
+        # assert x.shape[-1] == self.normalized_weight.shape[0]
+        out = x * self.normalized_weight * self.magnitude
+        if self.reduce:
+            return out.sum(-1)
+        else:
+            return out
+
+        
+class WeightedView(nn.Module):
+    r"""Calculate weighted view
+    
+    Args:
+        num_groups: int, number of groups (views)
+        reduce_dimension: bool, default False. If True, reduce dimension dim
+        dim: default -1. Only used when reduce_dimension is True
+        
+    Shape: 
+        - Input: if reduce_dimension is False, (N, num_features*num_groups); otherwise any shape with size num_groups along dim
+        - Output: (N, num_features)
+        
+    Attributes:
+        weight: (num_groups)
+        
+    Examples:
+    
+        >>> model = WeightedView(3)
+        >>> x = Variable(torch.randn(1, 6))
+        >>> print(model(x))
+        >>> model = WeightedView(3, True, 1)
+        >>> model(x.view(1,3,2))
+    """
+    def __init__(self, num_groups, reduce_dimension=False, dim=-1):
+        super(WeightedView, self).__init__()
+        self.num_groups = num_groups
+        self.reduce_dimension = reduce_dimension
+        self.dim = dim
+        self.weight = nn.Parameter(torch.Tensor(num_groups))
+        self.weight.data.uniform_(-1./num_groups, 1./num_groups)
+    
+    def forward(self, x):
+        self.normalized_weight = nn.functional.softmax(self.weight, dim=0)
+        if self.reduce_dimension:
+            assert x.size(self.dim) == self.num_groups
+            dim = self.dim if self.dim>=0 else self.dim+x.dim()
+            if dim == x.dim()-1:
+                out = (x * self.normalized_weight).sum(-1)
+            else:
+                # this is tricky for the case when x.dim()>3
+                out = torch.transpose((x.transpose(dim,-1)*self.normalized_weight).sum(-1), dim, -1)
+        else:
+            assert x.dim() == 2
+            num_features = x.size(-1) // self.num_groups
+            out = (x.view(-1, self.num_groups, num_features).transpose(1, -1)*self.normalized_weight).sum(-1)
+        return out    
+
+    
+class AffinityKernel(nn.Module):
+    r"""Calculate new representation for each point based on its k-nearest-neighborhood
+    
+    Args:
+        in_dim: int
+        hidden_dim: int
+        out_dim: int or None
+                 not used if interaction_only is True
+        interaction_only: if True, do not use out_dim at all (the second linear layer w2 is omitted)
+        pooling: 'average' or 'max', use average or max pooling over the neighborhood
+        
+        k, graph, feature_subset are the same as in GraphAttentionLayer, 
+            except that out_indices defaults to None (output will have shape (N, *))
+   
+    Shape:
+        - Input: (N, in_dim) 
+        - Output: (N, out_dim)
+        
+    Attributes:
+        w: (hidden_dim, 2*in_dim)
+        w2: (out_dim, in_dim+hidden_dim); if interaction_only is True, w2 is None
+        
+    Examples:
+        >>> m = AffinityKernel(5, 10, 15)
+        >>> x = Variable(torch.randn(10,5))
+        >>> m(x)
+    
+    """
+    def __init__(self, in_dim, hidden_dim, out_dim, k=None, graph=None, feature_subset=None,
+                 nonlinearity_1=nn.Hardtanh(), nonlinearity_2=None, interaction_only=False, 
+                pooling='average', reset_graph_every_forward=False, out_indices=None):
+        super(AffinityKernel, self).__init__()
+        self.k = k
+        self.graph = graph
+        self.cal_graph = True if self.graph is None else False
+        self.feature_subset = feature_subset
+        self.nonlinearity_1 = nonlinearity_1
+        self.nonlinearity_2 = nonlinearity_2
+        self.pooling = pooling
+        self.reset_graph_every_forward = reset_graph_every_forward
+        self.out_indices = out_indices
+        assert self.pooling=='average' or self.pooling=='max'
+        
+        self.w = nn.Parameter(torch.Tensor(hidden_dim, 2*in_dim))
+        std = 1./np.sqrt(self.w.size(1))
+        self.w.data.uniform_(-std, std)
+        self.w2 = None
+        if not interaction_only:
+            assert isinstance(out_dim, int)
+            self.w2 = nn.Parameter(torch.Tensor(out_dim, in_dim+hidden_dim))
+            std = 1./np.sqrt(self.w2.size(1))
+            self.w2.data.uniform_(-std, std)
+            
+    def reset_graph(self, graph=None):
+        self.graph = graph
+        self.cal_graph = True if self.graph is None else False
+        
+    def reset_out_indices(self, out_indices=None):
+        self.out_indices = out_indices
+    
+    def forward(self, x):
+        N, in_dim = x.size()
+        out = Variable(torch.zeros(N, self.w.size(0)).type(dtype['float']))
+        k = self.k if isinstance(self.k, int) and self.k<x.size(0) else x.size(0)
+        
+        # This has not been checked carefully
+        if self.reset_graph_every_forward:
+            self.reset_graph()
+            self.reset_out_indices()
+
+        if self.cal_graph: # probably a redundant attribute; should rely only on self.graph
+            if self.feature_subset is None:
+                feature_subset = dtype['long'](range(x.size(1)))
+            else:
+                feature_subset = self.feature_subset
+            d = torch.norm(x[:,feature_subset] - x[:,feature_subset].unsqueeze(1), dim=-1)
+            _, self.graph = torch.topk(d, k, dim=-1, largest=False)
+            
+        for i in range(N):
+            neighbor_idx = self.graph[i][:k]
+            neighbor_mat = torch.cat([x[neighbor_idx], x[i,None]*Variable(torch.ones(len(neighbor_idx), 1).type(
+                dtype['float']))], dim=1)
+            h = nn.functional.linear(neighbor_mat, self.w)
+            if self.nonlinearity_1 is not None:
+                h = self.nonlinearity_1(h)
+            if self.pooling == 'average':
+                out[i] = h.mean(dim=0)
+            elif self.pooling == 'max':
+                # torch.max() returns a tuple
+                out[i] = h.max(dim=0)[0]
+                
+        if self.w2 is not None:
+            out = nn.functional.linear(torch.cat([out, x], dim=-1), self.w2)
+        if self.nonlinearity_2 is not None:
+            out = self.nonlinearity_2(out)
+
+        out_indices = range(N) if self.out_indices is None else self.out_indices
+        return out[out_indices]
+
+    
+class AffinityNet(nn.Module):
+    r"""Multiple AffinityKernel layers
+        Same interface as AffinityKernel, except that arguments should be iterables where appropriate,
+        plus new arguments such as return_all
+    
+    Args:
+        return_all: if true, return concatenated features from all AffinityKernel Layers
+        add_global_feature: if true, add global features at the last of the output
+                            only used when return_all is true
+        k_avg: when performing global pooling on the last layer, how many neighbors should we use for pooling
+               if k_avg is None, then use all nodes for global pooling. Otherwise, it is "local" pooling
+        global_pooling: 'average' or 'max' pooling
+        pool_last_layer_only: if True, only pool the last layer as the global feature, 
+                                otherwise pool all previously concatenated output (and the input if forward_input is True)
+                              only used when return_all is True
+        forward_input: if True, add input in the beginning of the output
+                       only used when return_all is True
+        dense: if True, feed all previous input and output as current input
+               inspired by DenseNet
+        in_dim: int
+        hidden_dim: iterable of int
+        out_dim: iterable of int; 
+                 if initialized int or None, then transform it to iterable of length hidden_dim
+        k: iterable of int; process it similar to out_dim
+        use_initial_graph: if True, calculate the graph from the input once and reuse it for subsequent layers
+        reset_graph_every_forward: if True, reset graph, out_indices, k_avg in the beginning of every forward
+        out_indices: default None, output.size(0)==x.size(0)
+                        if not None, output.size(0)==len(out_indices)
+        k_avg_graph: either 'single' or 'mix'; 
+                    if 'single', use the provided graph only for pooling;
+                    if 'mix', append calculated graph based on current features to the provided graph 
+                        in case the provided graph has a node degree less than k_avg
+                    only used when return_all, add_global_feature, k_avg < x.size(0) are all True, and 
+                     the provided graph is not a torch.LongTensor or Variable
+        graph, non_linearity_1, non_linearity_2, feature_subset, interaction_only, pooling, 
+            are all the same as those in AffinityKernel except that they will be iterables
+    
+    Shape:
+        - Input: (N, *, in_dim)
+        - Out: (N, ?) ? to be determined by hidden_dim, out_dim and return_all
+        
+    Attributes:
+        a list of parameters of AffinityKernel
+        
+    Examples:
+    
+        >>> m = AffinityNet(5, [10,3,5], [7,3,4], return_all=True)
+        >>> x = Variable(torch.randn(1,5))
+        >>> m(x)
+    """
+    def __init__(self, in_dim, hidden_dim, out_dim, k=None, graph=None, feature_subset=None,
+                 nonlinearity_1=nn.Hardtanh(), nonlinearity_2=None, interaction_only=False, 
+                pooling='average', return_all=False, add_global_feature=True, k_avg=None, 
+                 global_pooling='max', pool_last_layer_only=True,
+                forward_input=True, dense=True, use_initial_graph=True, reset_graph_every_forward=False,
+                 out_indices=None, k_avg_graph='single'):
+        super(AffinityNet, self).__init__()
+        self.return_all = return_all
+        self.add_global_feature = add_global_feature
+        self.global_pooling = global_pooling
+        self.pool_last_layer_only = pool_last_layer_only
+        self.k_avg = k_avg
+        self.forward_input = forward_input
+        self.dense = dense
+        self.use_initial_graph = use_initial_graph
+        self.reset_graph_every_forward = reset_graph_every_forward
+        self.out_indices = out_indices
+        self.k_avg_graph = k_avg_graph
+        
+        assert self.global_pooling=='average' or self.global_pooling=='max'
+        assert self.k_avg_graph=='single' or self.k_avg_graph=='mix'
+        
+        num_layers = len(hidden_dim)
+        self.num_layers = num_layers
+        out_dim = get_iterator(out_dim, num_layers)
+        k = get_iterator(k, num_layers)
+        graph = get_iterator(graph, num_layers)
+        self.graph = graph
+        feature_subset = get_iterator(feature_subset, num_layers) # should be None almost all the time
+        nonlinearity_1 = get_iterator(nonlinearity_1, num_layers)
+        nonlinearity_2 = get_iterator(nonlinearity_2, num_layers)
+        interaction_only = get_iterator(interaction_only, num_layers)
+        pooling = get_iterator(pooling, num_layers)
+        
+        self.features = nn.ModuleList()
+        for i in range(num_layers):
+            self.features.append(
+                AffinityKernel(in_dim=in_dim, hidden_dim=hidden_dim[i], out_dim=out_dim[i],
+                              k=k[i], graph=graph[i], feature_subset=feature_subset[i], 
+                               nonlinearity_1=nonlinearity_1[i], nonlinearity_2=nonlinearity_2[i],
+                              interaction_only=interaction_only[i], pooling=pooling[i],
+                              reset_graph_every_forward=False, out_indices=None))
+            
+            new_dim = hidden_dim[i] if interaction_only[i] else out_dim[i]
+            if self.dense:
+                if i == 0 and not self.forward_input:
+                    in_dim = new_dim
+                else:
+                    in_dim += new_dim
+            else:
+                in_dim = new_dim
+                
+    def reset_graph(self, graph=None):
+        graph = get_iterator(graph, self.num_layers)
+        for i in range(self.num_layers):
+            getattr(self.features, str(i)).reset_graph(graph[i])
+        self.graph = graph
+    
+    def reset_k_avg(self, k_avg=None):
+        self.k_avg = k_avg
+        
+    def reset_out_indices(self, out_indices=None):
+        self.out_indices = out_indices
+        # all previous layers' out_indices are None
+        # could be wrong; not checked carefully
+        for i in range(self.num_layers):
+            getattr(self.features, str(i)).reset_out_indices()
+    
+    def forward(self, x):
+        N = x.size(0)
+        
+        # this condition might be buggy
+        if self.reset_graph_every_forward:
+            self.reset_graph()
+            self.reset_k_avg()
+            self.reset_out_indices()
+            
+        if self.graph[0] is None and self.use_initial_graph:
+            d = torch.norm(x-x[:,None], dim=-1)
+            _, graph = d.sort()
+            self.reset_graph(graph)
+            
+        if self.forward_input:
+            y = [x]
+        else:
+            y = []
+        out = x
+        for f in self.features:
+            out = f(out)   
+            y.append(out)
+            if self.dense:
+                out = torch.cat(y, -1)
+                
+        # Very tricky; still not clear whether this is done correctly
+        out_indices = range(N) if self.out_indices is None else self.out_indices
+        
+        if self.return_all:
+            if self.add_global_feature:
+                pool_feature = y[-1] if self.pool_last_layer_only else out
+                dim_pool = 0
+                if isinstance(self.k_avg, int) and self.k_avg < N:
+                    if self.graph[-1] is None:
+                        d = torch.norm(x-x[:,None], dim=-1)
+                        _, graph = d.sort()
+                    else:
+                        # when graph is given or set
+                        graph = self.graph[-1]
+                        assert len(graph) == N
+                        # handling the case when graph is a list of torch.LongTensor
+                        # the size of neighborhood of each node may vary
+                        if not isinstance(graph, (dtype['long'], Variable)): 
+                            # save some computation if graph is already a torch.LongTensor or Variable
+                            if self.k_avg_graph == 'single':
+                                graph = torch.stack([dtype['long']([g[i%len(g)] for i in range(N)])
+                                             for g in graph], dim=0)
+                            elif self.k_avg_graph == 'mix': # very tricky here; spent quite some time debugging
+                                d = torch.norm(x-x[:,None], dim=-1)
+                                _, graph2 = d.sort()
+                                graph = torch.stack([torch.cat(
+                                    [dtype['long'](g), dtype['long'](
+                                        [i for i in graph2[j].data if i not in g])])
+                                                     for j, g in enumerate(graph)], dim=0)
+                            
+                    pool_feature = (pool_feature[graph[:,:self.k_avg].contiguous().view(-1)].
+                                    contiguous().view(N, self.k_avg, -1))
+                    dim_pool=1
+                if self.global_pooling == 'average':
+                    global_feature = pool_feature.mean(dim=dim_pool)
+                elif self.global_pooling == 'max':
+                    # torch.max() returns a tuple
+                    global_feature = pool_feature.max(dim=dim_pool)[0]
+                if dim_pool == 0:
+                    global_feature = global_feature * Variable(torch.ones(x.size(0),1).type(dtype['float']))
+                y.append(global_feature)
+            return torch.cat(y, -1)[out_indices]
+        else:
+            return y[-1][out_indices]
+
+        
+class StackedAffinityNet(nn.Module): 
+    r"""Stack multiple simple AffinityNet layers with bottleneck layers in the middle
+    enable concatenating the output of intermediate output as output
+    For simplification, each AffinityNet unit have the same hidden_dim and out_dim
+    
+    
+    Args:
+        L: number of layers within each AffinityNet unit
+        max_dim: the maximum dimension produced by bottleneck layer, can be an iterable
+        forward_input_global: if True, add original input to the head of output
+                              only used when return_all_global is true
+        return_all_global: if True, return all intermediate features (and input if forward_input_global is True)
+        dense_global: if True, the output of previous bottleneck layers (extracted features) will be concatenated
+                        with the current input
+        set_bottleneck_dim: if True, every bottleneck layer will be determined by max_dim only
+        return_layers: if not None, only output the selected bottleneck layers
+                        only used when return_all_global is True
+                        Very buggy when interacting with forward_input_global
+        hierarchical_pooling: if True, set k_avg = round(np.exp(np.log(N)/num_blocks))
+                                where N = x.size(0), num_blocks = len(hidden_dim)
+       
+    Shape:
+    
+    Attributes:
+    
+    Examples:
+    
+    >>> m = StackedAffinityNet(2, [2,3], [2,3], 3)
+    >>> x = Variable(torch.randn(5,2))
+    >>> m(x)
+    """
+    def __init__(self, in_dim, hidden_dim, out_dim, L, k=None, graph=None, feature_subset=None,
+                 nonlinearity_1=nn.Hardtanh(), nonlinearity_2=None, interaction_only=False, 
+                pooling='average', return_all=True, add_global_feature=True, k_avg=None, 
+                 global_pooling='max', pool_last_layer_only=True,
+                forward_input=True, dense=True, max_dim=10, set_bottleneck_dim=True, forward_input_global=False,
+                dense_global=True, return_all_global=True, return_layers=None, use_initial_graph=True,
+                 hierarchical_pooling=True, reset_graph_every_forward=False,
+                out_indices=None, k_avg_graph='single'):
+        super(StackedAffinityNet, self).__init__()
+        assert isinstance(hidden_dim, collections.abc.Iterable)
+        num_blocks = len(hidden_dim)
+        self.num_blocks = num_blocks
+        out_dim = get_iterator(out_dim, num_blocks)
+        k = get_iterator(k, num_blocks)
+        graph = get_iterator(graph, num_blocks)
+        self.graph = graph
+        feature_subset = get_iterator(feature_subset, num_blocks) # should be None almost all the time
+        nonlinearity_1 = get_iterator(nonlinearity_1, num_blocks)
+        nonlinearity_2 = get_iterator(nonlinearity_2, num_blocks)
+        interaction_only = get_iterator(interaction_only, num_blocks)
+        pooling = get_iterator(pooling, num_blocks)
+        return_all = get_iterator(return_all, num_blocks)
+        add_global_feature = get_iterator(add_global_feature, num_blocks)
+        k_avg = get_iterator(k_avg, num_blocks)
+        self.k_avg = k_avg
+        global_pooling = get_iterator(global_pooling, num_blocks)
+        pool_last_layer_only = get_iterator(pool_last_layer_only, num_blocks)
+        forward_input = get_iterator(forward_input, num_blocks)
+        dense = get_iterator(dense, num_blocks)
+        max_dim = get_iterator(max_dim, num_blocks)
+        self.forward_input_global = forward_input_global
+        self.dense_global = dense_global
+        self.return_all_global = return_all_global
+        self.return_layers = return_layers
+        self.use_initial_graph = use_initial_graph
+        self.hierarchical_pooling = hierarchical_pooling
+        self.reset_graph_every_forward = reset_graph_every_forward
+        self.out_indices = out_indices
+        
+        self.blocks = nn.ModuleList()
+        dim_sum = 0
+        for i in range(num_blocks):
+            self.blocks.append(
+                AffinityNet(in_dim=in_dim, hidden_dim=[hidden_dim[i]]*L, out_dim=out_dim[i], k=k[i], 
+                            graph=graph[i], feature_subset=feature_subset[i], nonlinearity_1=nonlinearity_1[i],
+                            nonlinearity_2=nonlinearity_2[i], interaction_only=interaction_only[i], 
+                            pooling=pooling[i], return_all=return_all[i],
+                            add_global_feature=add_global_feature[i], k_avg=k_avg[i], 
+                            global_pooling=global_pooling[i], pool_last_layer_only=pool_last_layer_only[i],
+                            forward_input=forward_input[i], dense=dense[i], use_initial_graph=use_initial_graph,
+                           reset_graph_every_forward=False, out_indices=None, k_avg_graph=k_avg_graph)
+            )
+            if return_all[i]:
+                new_dim = hidden_dim[i]*L if interaction_only[i] else out_dim[i]*L
+                if forward_input[i]:
+                    new_dim += in_dim
+                if add_global_feature[i]:
+                    if pool_last_layer_only[i]:
+                        new_dim += hidden_dim[i] if interaction_only[i] else out_dim[i]
+                    else:
+                        new_dim *= 2            
+            else:
+                new_dim = hidden_dim[i] if interaction_only[i] else out_dim[i]
+            
+            if dense_global:
+                new_dim += dim_sum
+            in_dim = max_dim[i] if set_bottleneck_dim else min(new_dim, max_dim[i])
+            # use linear layer or AffinityNet or AffinityKernel?
+            self.blocks.add_module('bottleneck'+str(i),
+                                   nn.Sequential(
+                                       nn.Linear(new_dim, in_dim),
+                                       nonlinearity_1[i]
+                                   ))
+            dim_sum += in_dim
+            
+    def reset_graph(self, graph=None):
+        # could be buggy here
+        # assume every block consists of exactly two layers: an AffinityNet and a bottleneck layer
+        graph = get_iterator(graph, self.num_blocks)
+        for i in range(self.num_blocks):
+            getattr(self.blocks, str(i*2)).reset_graph(graph[i])
+        self.graph = graph
+            
+    def reset_k_avg(self, k_avg=None):
+        # similar to reset_graph 
+        # could be buggy here
+        # assume every block consists of exactly two layers: an AffinityNet and a bottleneck layer
+        k_avg = get_iterator(k_avg, self.num_blocks)
+        for i in range(self.num_blocks):
+            getattr(self.blocks, str(i*2)).reset_k_avg(k_avg[i])
+        self.k_avg = k_avg
+    
+    def reset_out_indices(self, out_indices=None):
+        self.out_indices = out_indices
+        # Very, very buggy here; not checked carefully
+        # out_indices should be None until the last layer
+        for i in range(self.num_blocks):
+            getattr(self.blocks, str(i*2)).reset_out_indices()
+    
+    def forward(self, x):
+        if self.reset_graph_every_forward:
+            self.reset_graph()
+            self.reset_k_avg()
+            self.reset_out_indices()
+            
+        if self.graph[0] is None and self.use_initial_graph:
+            d = torch.norm(x-x[:,None], dim=-1)
+            _, graph = d.sort()
+            self.reset_graph(graph)
+            
+        if self.k_avg[0] is None and self.hierarchical_pooling:
+            k = int(round(np.exp(np.log(x.size(0))/self.num_blocks)))
+            ks = [k]
+            for i in range(self.num_blocks-1):
+                if i == self.num_blocks-2:
+                    ks.append(x.size(0))  # pool all points in last layer
+                else:
+                    ks.append(ks[-1]*k)
+            self.reset_k_avg(ks)
+            
+        y = []
+        out = x
+        for name, module in self.blocks._modules.items():
+            if name.startswith('bottleneck') and self.dense_global:
+                out = torch.cat(y+[out], -1)
+            out = module(out)
+            if name.startswith('bottleneck'):
+                y.append(out)
+        
+        # this is very buggy; I spent a long time debugging it
+        # still not clear whether it is correct
+        out_indices = range(x.size(0)) if self.out_indices is None else self.out_indices
+        
+        if self.return_all_global:
+            if self.forward_input_global:
+                y = [x] + y
+            if isinstance(self.return_layers, collections.abc.Iterable):
+                y = [h for i, h in enumerate(y) if i in self.return_layers]
+            return torch.cat(y, -1)[out_indices]
+        else:
+            return y[-1][out_indices]
+
+
+class GraphAttentionLayer(nn.Module):
+    r"""Attention layer
+    
+    Args:
+        in_dim: int, dimension of input
+        out_dim: int, dimension of output
+        out_indices: torch.LongTensor, the indices of nodes whose representations are 
+                     to be computed
+                     Default None, calculate all node representations
+                     If not None, need to reset it every time model is run
+        feature_subset: torch.LongTensor. Default None, use all features
+        kernel: 'affine' (default), use affine function to calculate attention 
+                'gaussian', use weighted Gaussian kernel to calculate attention
+        k: int, number of nearest neighbors used to calculate node representations
+           Default None, use all nodes
+        graph: a list of torch.LongTensor, corresponding to the nearest neighbors of nodes 
+               whose representations are to be computed
+               Make sure graph and out_indices are aligned properly
+        use_previous_graph: only used when graph is None
+                            if True, use the input to calculate the graph
+                            otherwise, use the newly transformed output
+        nonlinearity_1: nn.Module, non-linear activation applied after the linear transformation 
+        nonlinearity_2: nn.Module, non-linear activation applied after the attention operation
+    
+    Shape:
+        - Input: (N, in_dim) graph node representations
+        - Output: (N, out_dim) if out_indices is None 
+                  else (len(out_indices), out_dim)
+        
+    Attributes:
+        weight: (out_dim, in_dim)
+        a: (out_dim) if kernel is 'gaussian', 'inner-product', 'cosine', 'avg_pool', or 'key-value'
+           (out_dim*2) if kernel is 'affine'
+           
+    Examples:
+    
+        >>> m = GraphAttentionLayer(2,2,feature_subset=torch.LongTensor([0,1]), 
+                        graph=torch.LongTensor([[0,5,1], [3,4,6]]), out_indices=[0,1], 
+                        kernel='gaussian', nonlinearity_1=None, nonlinearity_2=None)
+        >>> x = Variable(torch.randn(10,3))
+        >>> m(x)
+    """
+    def __init__(self, in_dim, out_dim, k=None, graph=None, out_indices=None, 
+                 feature_subset=None, kernel='affine', nonlinearity_1=nn.Hardtanh(),
+                 nonlinearity_2=None, use_previous_graph=True, reset_graph_every_forward=False,
+                no_feature_transformation=False, rescale=True, layer_norm=False, layer_magnitude=100,
+                key_dim=None, feature_selection_only=False):
+        super(GraphAttentionLayer, self).__init__()
+        self.in_dim = in_dim
+        self.graph = graph
+        if graph is None:
+            self.cal_graph = True
+        else:
+            self.cal_graph = False
+        self.use_previous_graph = use_previous_graph
+        self.reset_graph_every_forward = reset_graph_every_forward
+        self.no_feature_transformation = no_feature_transformation
+        if self.no_feature_transformation:
+            assert in_dim == out_dim
+        else:
+            self.weight = nn.Parameter(torch.Tensor(out_dim, in_dim))
+            # initialize parameters
+            std = 1. / np.sqrt(self.weight.size(1))
+            self.weight.data.uniform_(-std, std)
+        self.rescale = rescale
+        self.k = k
+        self.out_indices = out_indices
+        self.feature_subset = feature_subset
+        self.kernel = kernel
+        self.nonlinearity_1 = nonlinearity_1
+        self.nonlinearity_2 = nonlinearity_2
+        self.layer_norm = layer_norm
+        self.layer_magnitude = layer_magnitude
+        self.feature_selection_only = feature_selection_only
+
+        if kernel=='affine':
+            self.a = nn.Parameter(torch.Tensor(out_dim*2))
+        elif kernel=='gaussian' or kernel=='inner-product' or kernel=='avg_pool' or kernel=='cosine':
+            self.a = nn.Parameter(torch.Tensor(out_dim))
+        elif kernel=='key-value':
+            if key_dim is None:
+                self.key = None
+                key_dim = out_dim
+            else:
+                if self.use_previous_graph:
+                    self.key = nn.Linear(in_dim, key_dim)
+                else:
+                    self.key = nn.Linear(out_dim, key_dim)
+            self.key_dim = key_dim
+            self.a = nn.Parameter(torch.Tensor(out_dim))
+        else:
+            raise ValueError('kernel {0} is not supported'.format(kernel))
+        self.a.data.uniform_(0, 1)
+    
+    def reset_graph(self, graph=None):
+        self.graph = graph
+        self.cal_graph = True if self.graph is None else False
+        
+    def reset_out_indices(self, out_indices=None):
+        self.out_indices = out_indices
+    
+    def forward(self, x):
+        if self.reset_graph_every_forward:
+            self.reset_graph()
+            
+        N = x.size(0)
+        out_indices = dtype['long'](range(N)) if self.out_indices is None else self.out_indices
+        if self.feature_subset is not None:
+            x = x[:, self.feature_subset]
+        assert self.in_dim == x.size(1)
+         
+        if self.no_feature_transformation:
+            out = x
+        else:
+            out = nn.functional.linear(x, self.weight)
+        
+        feature_weight = nn.functional.softmax(self.a, dim=0) 
+        if self.rescale and self.kernel!='affine':
+            out = out*feature_weight
+            if self.feature_selection_only:
+                return out
+
+        if self.nonlinearity_1 is not None:
+            out = self.nonlinearity_1(out)
+        k = N if self.k is None else min(self.k, out.size(0))
+
+        if self.kernel=='key-value':
+            if self.key is None:
+                keys = x if self.use_previous_graph else out
+            else:
+                keys = self.key(x) if self.use_previous_graph else self.key(out)
+            norm = torch.norm(keys, p=2, dim=-1)
+            att = (keys[out_indices].unsqueeze(-2) * keys.unsqueeze(-3)).sum(-1) / (norm[out_indices].unsqueeze(-1)*norm)
+            att_, idx = att.topk(k, -1)
+            a = Variable(torch.zeros(att.size()).fill_(float('-inf')).type(dtype['float']))
+            a.scatter_(-1, idx, att_)
+            a = nn.functional.softmax(a, dim=-1)
+            y = (a.unsqueeze(-1)*out.unsqueeze(-3)).sum(-2)
+            if self.nonlinearity_2 is not None:
+                y = self.nonlinearity_2(y)
+            if self.layer_norm:
+                y = nn.functional.relu(y)  # maybe redundant; just play safe
+                y = y / y.sum(-1, keepdim=True) * self.layer_magnitude # <UncheckAssumption> y.sum(-1) > 0
+            return y
+
+        # BUG in the original condition `if self.graph is None`: self.graph would never be
+        # recomputed after the first forward pass
+        # replaced with the self.cal_graph flag below
+        if self.cal_graph:
+            if self.kernel != 'key-value':
+                features = x if self.use_previous_graph else out
+                dist = torch.norm(features.unsqueeze(1)-features.unsqueeze(0), p=2, dim=-1)
+                _, self.graph = dist.sort()
+                self.graph = self.graph[out_indices]               
+        y = Variable(torch.zeros(len(out_indices), out.size(1)).type(dtype['float']))
+        
+        for i, idx in enumerate(out_indices):
+            neighbor_idx = self.graph[i][:k]
+            if self.kernel == 'gaussian':
+                if self.rescale: # out has already been rescaled
+                    a = -torch.sum((out[idx] - out[neighbor_idx])**2, dim=1)
+                else:
+                    a = -torch.sum((feature_weight*(out[idx] - out[neighbor_idx]))**2, dim=1)
+            elif self.kernel == 'inner-product':
+                if self.rescale: # out has already been rescaled
+                    a = torch.sum(out[idx]*out[neighbor_idx], dim=1)
+                else:
+                    a = torch.sum(feature_weight*(out[idx]*out[neighbor_idx]), dim=1)
+            elif self.kernel == 'cosine':
+                if self.rescale: # out has already been rescaled
+                    norm = torch.norm(out[idx]) * torch.norm(out[neighbor_idx], p=2, dim=-1)
+                    a = torch.sum(out[idx]*out[neighbor_idx], dim=1) / norm
+                else:
+                    norm = torch.norm(feature_weight*out[idx]) * torch.norm(feature_weight*out[neighbor_idx], p=2, dim=-1)
+                    a = torch.sum(feature_weight*(out[idx]*out[neighbor_idx]), dim=1) / norm
+            elif self.kernel == 'affine':
+                a = torch.mv(torch.cat([(out[idx].unsqueeze(0) 
+                                         * Variable(torch.ones(len(neighbor_idx)).unsqueeze(1)).type(dtype['float'])), 
+                                        out[neighbor_idx]], dim=1), self.a)
+            elif self.kernel == 'avg_pool':
+                a = Variable(torch.ones(len(neighbor_idx)).type(dtype['float']))
+            a = nn.functional.softmax(a, dim=0)
+            # since sum(a) == 1, the following line should use torch.sum instead of torch.mean
+            y[i] = torch.sum(out[neighbor_idx]*a.unsqueeze(1), dim=0)
+        if self.nonlinearity_2 is not None:
+            y = self.nonlinearity_2(y)
+        if self.layer_norm:
+            y = nn.functional.relu(y)  # maybe redundant; just play safe
+            y = y / y.sum(-1, keepdim=True) * self.layer_magnitude # <UncheckAssumption> y.sum(-1) > 0
+        return y
+        
+        
+class GraphAttentionModel(nn.Module):
+    r"""Consist of multiple GraphAttentionLayer
+    
+    Args:
+        in_dim: int, num_features
+        hidden_dims: an iterable of int, len(hidden_dims) is number of layers
+        ks: an iterable of int, k for GraphAttentionLayer. 
+            Default None, use all neighbors for all GraphAttentionLayer
+        kernels, graphs, nonlinearities_1, nonlinearities_2, feature_subsets, out_indices, use_previous_graphs: 
+            an iterable of * for GraphAttentionLayer
+        
+    Shape:
+        - Input: (N, in_dim)
+        - Output: (x, hidden_dims[-1]), x=N if out_indices is None. Otherwise determined by out_indices
+    
+    Attributes:
+        weights: a list of weight for GraphAttentionLayer
+        a: a list of a for GraphAttentionLayer
+    
+    Examples:
+    
+        >>> m=GraphAttentionModel(5, [3,4], [3,3])
+        >>> x = Variable(torch.randn(10,5))
+        >>> m(x)
+    """
+    def __init__(self, in_dim, hidden_dims, ks=None, graphs=None, out_indices=None, feature_subsets=None,
+                 kernels='affine', nonlinearities_1=nn.Hardtanh(), nonlinearities_2=None,
+                 use_previous_graphs=True, reset_graph_every_forward=False, no_feature_transformation=False,
+                rescale=True):
+        super(GraphAttentionModel, self).__init__()
+        self.in_dim = in_dim
+        self.hidden_dims = hidden_dims
+        num_layers = len(hidden_dims)
+        self.no_feature_transformation = get_iterator(no_feature_transformation, num_layers)
+        for i in range(num_layers):
+            if self.no_feature_transformation[i]:
+                if i == 0:
+                    assert hidden_dims[0] == in_dim
+                else:
+                    assert hidden_dims[i-1] == hidden_dims[i]
+                
+        if ks is None or isinstance(ks, int):
+            ks = [ks]*num_layers
+        self.ks = ks
+        if graphs is None:
+            graphs = [None]*num_layers
+        self.graphs = graphs
+        self.reset_graph_every_forward = reset_graph_every_forward
+        if isinstance(kernels, str):
+            kernels = [kernels]*num_layers
+        self.kernels = kernels
+        if isinstance(nonlinearities_1, nn.Module) or nonlinearities_1 is None:
+            nonlinearities_1 = [nonlinearities_1]*num_layers
+        # Tricky: if self.nonlinearities_1 were assigned an nn.Module instance directly, it would become a 
+        # child module of self, and any reassignment would then also have to be an nn.Module
+        self.nonlinearities_1 = nonlinearities_1
+        if isinstance(nonlinearities_2, nn.Module) or nonlinearities_2 is None:
+            nonlinearities_2 = [nonlinearities_2]*num_layers
+        self.nonlinearities_2 = nonlinearities_2
+        self.out_indices = out_indices
+        if isinstance(out_indices, dtype['long']) or out_indices is None:
+            self.out_indices = [out_indices]*num_layers
+        self.feature_subsets = feature_subsets
+        if isinstance(feature_subsets, dtype['long']) or feature_subsets is None:
+            self.feature_subsets = [feature_subsets]*num_layers
+        self.use_previous_graphs = use_previous_graphs
+        if isinstance(use_previous_graphs, bool):
+            self.use_previous_graphs = [use_previous_graphs]*num_layers
+        self.rescale = get_iterator(rescale, num_layers)
+            
+        self.attention = nn.Sequential()
+        for i in range(num_layers):
+            self.attention.add_module('layer'+str(i), 
+                GraphAttentionLayer(in_dim if i==0 else hidden_dims[i-1], out_dim=hidden_dims[i], 
+                                    k=self.ks[i], graph=self.graphs[i], out_indices=self.out_indices[i],
+                                    feature_subset=self.feature_subsets[i], kernel=self.kernels[i],
+                                    nonlinearity_1=self.nonlinearities_1[i],
+                                    nonlinearity_2=self.nonlinearities_2[i], 
+                                    use_previous_graph=self.use_previous_graphs[i],
+                                   no_feature_transformation=self.no_feature_transformation[i],
+                                   rescale=self.rescale[i]))
+            
+    def reset_graph(self, graph=None):
+        num_layers = len(self.hidden_dims)
+        graph = get_iterator(graph, num_layers)
+        for i in range(num_layers):
+            getattr(self.attention, 'layer'+str(i)).reset_graph(graph[i])
+        self.graphs = graph  
+            
+    def reset_out_indices(self, out_indices=None):
+        num_layers = len(self.hidden_dims)
+        out_indices = get_iterator(out_indices, num_layers)
+        assert len(out_indices) == num_layers
+        for i in range(num_layers):
+            # probably out_indices should not be a list;
+            # only the last layer will output certain points, all previous ones should output all points
+            getattr(self.attention, 'layer'+str(i)).reset_out_indices(out_indices[i])
+        self.out_indices = out_indices
+            # functools.reduce(lambda m, a: getattr(m, a), ('attention.layer'+str(i)).split('.'), self).reset_out_indices(out_indices[i])
+        
+    def forward(self, x):
+        if self.reset_graph_every_forward:
+            self.reset_graph()
+            
+        return self.attention(x)
+    
+    
+class GraphAttentionGroup(nn.Module):
+    r"""Combine different view of data
+    
+    Args:
+        group_index: an iterable of torch.LongTensor or other type that can be subscripted by torch.Tensor;
+                     each element is feed to GraphAttentionModel as feature_subset
+        merge: if True, aggregate the output of each group (view);
+               Otherwise, concatenate the output of each group
+        in_dim: only used when group_index is None, otherwise determined by group_index
+        feature_subset: not used when group_index is not None: always set to None internally
+        out_dim, k, graph, out_indices, kernel, nonlinearity_1, nonlinearity_2, and
+            use_previous_graph are used similarly in GraphAttentionLayer
+            
+    Shape:
+        - Input: (N, in_dim)
+        - Output: (x, y) where x=N if out_indices is None else len(out_indices)
+                              y=out_dim if merge is True else out_dim*len(group_index)
+                              
+    Attributes:
+        weight: (out_dim, in_dim) 
+        a: (out_dim) if kernel='gaussian' else (out_dim * 2)
+        group_weight: (len(group_index)) if merge is True else None
+        
+    Examples:
+    
+        >>> m = GraphAttentionGroup(2, 2, k=None, graph=None, out_indices=None, 
+                 feature_subset=None, kernel='affine', nonlinearity_1=nn.Hardtanh(),
+                 nonlinearity_2=None, use_previous_graph=True, group_index=[range(2), range(2,4)], merge=False)
+        >>> x = Variable(torch.randn(5, 4))
+        >>> m(x)
+    """
+    def __init__(self, in_dim, out_dim, k=None, graph=None, out_indices=None, 
+                 feature_subset=None, kernel='affine', nonlinearity_1=nn.Hardtanh(),
+                 nonlinearity_2=None, use_previous_graph=True, group_index=None, merge=True,
+                 merge_type='sum', reset_graph_every_forward=False, no_feature_transformation=False,
+                rescale=True, merge_dim=None, layer_norm=False, layer_magnitude=100, key_dim=None):
+        super(GraphAttentionGroup, self).__init__()
+        self.group_index = group_index
+        num_groups = 0 if self.group_index is None else len(group_index) 
+        self.num_groups = num_groups
+        self.merge = merge
+        assert merge_type=='sum' or merge_type=='affine'
+        self.merge_type = merge_type
+        
+        self.components = nn.ModuleList()
+        self.group_weight = None
+        self.feature_weight = None
+        if group_index is None or len(group_index)==1:
+            self.components.append(GraphAttentionLayer(in_dim, out_dim, k, graph, out_indices, feature_subset,
+                                                       kernel, nonlinearity_1, nonlinearity_2,
+                                                       use_previous_graph,
+                                                       reset_graph_every_forward=False,
+                                                       no_feature_transformation=no_feature_transformation, 
+                                                       rescale=rescale, layer_norm=layer_norm,
+                                                       layer_magnitude=layer_magnitude, key_dim=key_dim))
+        else:
+            self.out_dim = get_iterator(out_dim, num_groups)
+            self.k = get_iterator(k, num_groups)
+            # BUG here: did not handle a special case where len(graph) = num_groups
+            self.graph = get_iterator(graph, num_groups)
+            # all groups' outputs have the same first dimension
+            self.out_indices = out_indices
+            # each group use all of its own features
+            self.feature_subset = None
+            self.kernel = get_iterator(kernel, num_groups, isinstance(kernel, str))
+            self.nonlinearity_1 = get_iterator(nonlinearity_1, num_groups)
+            self.nonlinearity_2 = get_iterator(nonlinearity_2, num_groups)
+            self.use_previous_graph = get_iterator(use_previous_graph, num_groups)
+            self.layer_norm = get_iterator(layer_norm, num_groups)
+            self.layer_magnitude = get_iterator(layer_magnitude, num_groups)
+            self.key_dim = get_iterator(key_dim, num_groups)
+            for i, idx in enumerate(group_index):
+                self.components.append(
+                    GraphAttentionLayer(len(idx), self.out_dim[i], self.k[i], self.graph[i],
+                                        self.out_indices, self.feature_subset, self.kernel[i],
+                                        self.nonlinearity_1[i], self.nonlinearity_2[i], 
+                                        self.use_previous_graph[i],
+                                        reset_graph_every_forward=False,
+                                        no_feature_transformation=no_feature_transformation,
+                                        rescale=rescale, layer_norm=self.layer_norm[i],
+                                        layer_magnitude=self.layer_magnitude[i],
+                                        key_dim=self.key_dim[i]))
+            if self.merge:
+                self.merge_dim = merge_dim if isinstance(merge_dim, int) else self.out_dim[0]
+                if self.merge_type=='sum':
+                    # all groups' output should have the same dimension
+                    for i in self.out_dim:
+                        assert i==self.merge_dim
+                    self.group_weight = nn.Parameter(torch.Tensor(num_groups))
+                    self.group_weight.data.uniform_(-1/num_groups,1/num_groups)
+                elif self.merge_type=='affine':
+                    # This is ugly and buggy
+                    # Do not assume each view has the same out_dim; the final output has merge_dim features
+                    # if merge_dim is None then set merge_dim=self.out_dim[0]
+                    self.feature_weight = nn.Parameter(torch.Tensor(self.merge_dim, sum(self.out_dim)))
+                    self.feature_weight.data.uniform_(-1./sum(self.out_dim), 1./sum(self.out_dim))
+                
+    def reset_graph(self, graph=None):
+        graphs = get_iterator(graph, self.num_groups)
+        for i, graph in enumerate(graphs):
+            getattr(self.components, str(i)).reset_graph(graph)
+        self.graph = graphs
+                
+    def reset_out_indices(self, out_indices=None):
+        num_groups = len(self.group_index)
+        out_indices = get_iterator(out_indices, num_groups)
+        for i in range(num_groups):
+            getattr(self.components, str(i)).reset_out_indices(out_indices[i])
+        self.out_indices = out_indices
+                
+    def forward(self, x):
+        if self.group_index is None or len(self.group_index)==1:
+            return self.components[0](x)
+        N = x.size(0) if self.out_indices is None else len(self.out_indices)
+        out = Variable(torch.zeros(N, functools.reduce(lambda x,y:x+y, self.out_dim)).type(dtype['float']))
+            
+        j = 0
+        for i, idx in enumerate(self.group_index):
+            out[:, j:j+self.out_dim[i]] = self.components[i](x[:,idx])
+            j += self.out_dim[i]
+            
+        if self.merge:
+            out_dim = self.merge_dim
+            num_groups = len(self.out_dim)
+            y = Variable(torch.zeros(N, out_dim).type(dtype['float']))
+            if self.merge_type == 'sum':
+                # normalize group weight
+                self.group_weight_normalized = nn.functional.softmax(self.group_weight, dim=0)
+                # Warning: cannot modify y in place, e.g. y += something (and possibly y = y + something)
+                y = (self.group_weight_normalized.unsqueeze(1) * out.view(N, num_groups, out_dim)).sum(1)
+            elif self.merge_type == 'affine':
+                y = nn.functional.linear(out, self.feature_weight)
+            return y
+        else:
+            return out
+        
+
+class MultiviewAttention(nn.Module):
+    r"""Stack GraphAttentionGroup layers; 
+        For simplicity, assume that within each layer, the parameters of every group have the same shape
+    
+    Args:
+        Has the same interface with GraphAttentionGroup, except
+            merge: a list of bool; default None, in which case it is set to [False, False, ..., False, True] internally 
+            hidden_dims: must be an iterable of int (len(hidden_dims) == num_layers) 
+                                                or iterable (len(hidden_dims[0]) == num_views)
+
+        Warnings:
+            Be careful when using out_indices and feature_subset; they can be buggy
+           
+    Shape:
+        - Input: (N, *)
+        - Output: 
+    
+    Attributes:
+        Variables of each GraphAttentionGroupLayer
+    
+    Examples:
+    
+        >>> m = MultiviewAttention(4, [3,2], group_index=[range(2), range(2,4)])
+        >>> x = Variable(torch.randn(1, 4))
+        >>> print(m(x))
+        >>> model = FeatureExtractor(m.layers, [0,1])
+        >>> print(model(x))
+    """
+    def __init__(self, in_dim, hidden_dims, k=None, graph=None, out_indices=None, 
+                 feature_subset=None, kernel='affine', nonlinearity_1=nn.Hardtanh(),
+                 nonlinearity_2=None, use_previous_graph=True, group_index=None, merge=None,
+                merge_type='sum', reset_graph_every_forward=False, no_feature_transformation=False,
+                rescale=True, merge_dim=None, layer_norm=False, layer_magnitude=100, 
+                key_dim=None):
+        super(MultiviewAttention, self).__init__()
+        assert isinstance(in_dim, int)
+        assert isinstance(hidden_dims, collections.abc.Iterable)
+        self.reset_graph_every_forward = reset_graph_every_forward
+        self.hidden_dims = hidden_dims
+        num_layers = len(hidden_dims)
+        self.num_layers = num_layers
+        if group_index is None:
+            group_index = [range(in_dim)] if feature_subset is None else [feature_subset]
+        if merge is None:
+            merge = [False]*(num_layers-1) + [True]
+        elif isinstance(merge, bool):
+            merge = get_iterator(merge, num_layers)
+        out_indices = get_iterator(out_indices, num_layers)
+        k = get_iterator(k, num_layers)
+        no_feature_transformation = get_iterator(no_feature_transformation, num_layers)
+        rescale = get_iterator(rescale, num_layers)
+        # buggy here: interact with merge
+        merge_dim = get_iterator(merge_dim, num_layers)
+
+        if layer_norm is True:
+            layer_norm = [True]*(num_layers-1) + [False]
+        layer_norm = get_iterator(layer_norm, num_layers)
+        layer_magnitude = get_iterator(layer_magnitude, num_layers)
+        key_dim = get_iterator(key_dim, num_layers)
+
+        self.layers = nn.Sequential()
+        for i in range(num_layers):
+            self.layers.add_module(str(i),
+                GraphAttentionGroup(in_dim, hidden_dims[i], k[i], graph, out_indices[i], None, 
+                                    kernel, nonlinearity_1, nonlinearity_2, use_previous_graph, 
+                                    group_index, merge[i], merge_type, reset_graph_every_forward=False,
+                                    no_feature_transformation=no_feature_transformation[i], 
+                                    rescale=rescale[i], merge_dim=merge_dim[i],
+                                    layer_norm=layer_norm[i], layer_magnitude=layer_magnitude[i],
+                                    key_dim=key_dim[i]))
+            # Very Very buggy here
+            # assume hidden_dims[i] is int or [int, int] 
+            h = get_iterator(hidden_dims[i], len(group_index))
+            if merge[i]:
+                in_dim = h[0] if merge_dim[i] is None else merge_dim[i]
+                group_index = [range(in_dim)]
+            else:
+                in_dim = sum(h)
+                group_index = []
+                cnt = 0
+                for tmp in h:
+                    group_index.append(range(cnt,cnt+tmp))
+                    cnt += tmp
+                    
+    def reset_graph(self, graph=None):
+        for i in range(self.num_layers):
+            getattr(self.layers, str(i)).reset_graph(graph)
+        self.graph = graph
+                    
+    def reset_out_indices(self, out_indices=None):
+        num_layers = len(self.hidden_dims)
+        out_indices = get_iterator(out_indices, num_layers)
+        for i in range(num_layers):
+            getattr(self.layers, str(i)).reset_out_indices(out_indices[i])
+        self.out_indices = out_indices
+                    
+    def forward(self, x):
+        if self.reset_graph_every_forward:
+            self.reset_graph()
+            
+        return self.layers(x)
\ No newline at end of file