Diff of /simutils.py [000000] .. [39fb2b]

--- a
+++ b/simutils.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+
+class LinearRegressionModel(nn.Module):
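+    # Ordinary linear regression with p inputs and a single output.
+    # If weights (a length-p sequence) and/or bias (a scalar) are given, the layer's
+    # parameters are set to those values instead of the random initialisation.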
+    def __init__(self, p, weights=None, bias=None):
+        super(LinearRegressionModel, self).__init__()
+        self.linear = nn.Linear(p, 1)
+        if weights is not None:
+            self.linear.weight = Parameter(torch.Tensor([weights]))
+        if bias is not None:
+            self.linear.bias = Parameter(torch.Tensor([bias]))
+
+    def forward(self, x):
+        return self.linear(x)
+
+class LogisticRegressionModel(nn.Module):
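+    # Logistic regression: the same linear layer as above with a sigmoid on the output.
+    # weights and bias optionally fix the layer's parameters, as in LinearRegressionModel.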
+    def __init__(self, p, weights=None, bias=None):
+        super(LogisticRegressionModel, self).__init__()
+        self.linear = nn.Linear(p, 1)
+        if weights is not None:
+            self.linear.weight = Parameter(torch.Tensor([weights]))
+        if bias is not None:
+            self.linear.bias = Parameter(torch.Tensor([bias]))
+
+    def forward(self, x):
+        return torch.sigmoid(self.linear(x))
+
+# model_modules["Logistic"](3, (1,1,1), 0).forward(torch.zeros([1,3]))
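+# (model_modules is presumably a name -> class mapping defined in the calling code.)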
+