Diff of /simutils.py [000000] .. [39fb2b]

Switch to unified view

a b/simutils.py
1
import torch
2
import torch.nn as nn
3
from torch.nn.parameter import Parameter
4
5
class LinearRegressionModel(nn.Module):
6
    def __init__(self, p, weights = None, bias = None):
7
        super(LinearRegressionModel, self).__init__()
8
        self.linear = nn.Linear(p, 1)
9
        if weights is not None:
10
            self.linear.weight = Parameter(torch.Tensor([weights]))
11
        if bias is not None:
12
            self.linear.bias = Parameter(torch.Tensor([bias]))
13
14
    def forward(self, x):
15
        return self.linear(x)
16
17
class LogisticRegressionModel(nn.Module):
18
    def __init__(self, p, weights = None, bias = None):
19
        super(LogisticRegressionModel, self).__init__()
20
        self.linear = nn.Linear(p, 1)
21
        if weights is not None:
22
            self.linear.weight = Parameter(torch.Tensor([weights]))
23
        if bias is not None:
24
            self.linear.bias = Parameter(torch.Tensor([bias]))
25
26
    def forward(self, x):
27
        return torch.sigmoid(self.linear(x))
28
29
# model_modules["Logistic"](3, (1,1,1), 0).forward(torch.zeros([1,3]))
30