Diff of /SAC/utils.py [000000] .. [9f010e]

Switch to side-by-side view

--- a
+++ b/SAC/utils.py
@@ -0,0 +1,36 @@
+import numpy as np
+from pyquaternion import Quaternion
+import math
+import pandas as pd
+import matplotlib.pyplot as pp
+import skvideo.io
+import os
+import pickle
+from copy import deepcopy
+import torch
+
+def create_log_gaussian(mean, log_std, t):
+    quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
+    l = mean.shape
+    log_z = log_std
+    z = l[-1] * math.log(2 * math.pi)
+    log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
+    return log_p
+
+def logsumexp(inputs, dim=None, keepdim=False):
+    if dim is None:
+        inputs = inputs.view(-1)
+        dim = 0
+    s, _ = torch.max(inputs, dim=dim, keepdim=True)
+    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
+    if not keepdim:
+        outputs = outputs.squeeze(dim)
+    return outputs
+
+def soft_update(target, source, tau):
+    for target_param, param in zip(target.parameters(), source.parameters()):
+        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
+
+def hard_update(target, source):
+    for target_param, param in zip(target.parameters(), source.parameters()):
+        target_param.data.copy_(param.data)