|
a |
|
b/SAC/utils.py |
|
|
1 |
import numpy as np |
|
|
2 |
from pyquaternion import Quaternion |
|
|
3 |
import math |
|
|
4 |
import pandas as pd |
|
|
5 |
import matplotlib.pyplot as pp |
|
|
6 |
import skvideo.io |
|
|
7 |
import os |
|
|
8 |
import pickle |
|
|
9 |
from copy import deepcopy |
|
|
10 |
import torch |
|
|
11 |
|
|
|
12 |
def create_log_gaussian(mean, log_std, t): |
|
|
13 |
quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2)) |
|
|
14 |
l = mean.shape |
|
|
15 |
log_z = log_std |
|
|
16 |
z = l[-1] * math.log(2 * math.pi) |
|
|
17 |
log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z |
|
|
18 |
return log_p |
|
|
19 |
|
|
|
20 |
def logsumexp(inputs, dim=None, keepdim=False): |
|
|
21 |
if dim is None: |
|
|
22 |
inputs = inputs.view(-1) |
|
|
23 |
dim = 0 |
|
|
24 |
s, _ = torch.max(inputs, dim=dim, keepdim=True) |
|
|
25 |
outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() |
|
|
26 |
if not keepdim: |
|
|
27 |
outputs = outputs.squeeze(dim) |
|
|
28 |
return outputs |
|
|
29 |
|
|
|
30 |
def soft_update(target, source, tau): |
|
|
31 |
for target_param, param in zip(target.parameters(), source.parameters()): |
|
|
32 |
target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) |
|
|
33 |
|
|
|
34 |
def hard_update(target, source): |
|
|
35 |
for target_param, param in zip(target.parameters(), source.parameters()): |
|
|
36 |
target_param.data.copy_(param.data) |