Diff of /SAC/utils.py [000000] .. [9f010e]

Switch to unified view

a b/SAC/utils.py
1
import numpy as np
2
from pyquaternion import Quaternion
3
import math
4
import pandas as pd
5
import matplotlib.pyplot as pp
6
import skvideo.io
7
import os
8
import pickle
9
from copy import deepcopy
10
import torch
11
12
def create_log_gaussian(mean, log_std, t):
13
    quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
14
    l = mean.shape
15
    log_z = log_std
16
    z = l[-1] * math.log(2 * math.pi)
17
    log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
18
    return log_p
19
20
def logsumexp(inputs, dim=None, keepdim=False):
21
    if dim is None:
22
        inputs = inputs.view(-1)
23
        dim = 0
24
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
25
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
26
    if not keepdim:
27
        outputs = outputs.squeeze(dim)
28
    return outputs
29
30
def soft_update(target, source, tau):
31
    for target_param, param in zip(target.parameters(), source.parameters()):
32
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
33
34
def hard_update(target, source):
35
    for target_param, param in zip(target.parameters(), source.parameters()):
36
        target_param.data.copy_(param.data)