a b/gen_unfold_template/gen_adaptivesurfs.py
1
import numpy as np
2
import scipy
3
import nibabel as nib
4
import matplotlib.pyplot as plt
5
from scipy.spatial import Delaunay
6
import matplotlib
7
matplotlib.use('Agg')
8
from scipy.interpolate import RegularGridInterpolator
9
10
11
# this script creates a mesh whose triangulation density is locally adaptive, based on an average surface-area map.
12
13
# the original grid spacing was 128 points over 20mm (or 20/128)
14
15
# since surf vertex area is proportional to grid spacing squared,
16
# if S_target and S_in are the target and input surface areas, then,
17
# adjustment to get corrected grid spacing: 
18
19
# (20/N)^2 = (S_target / S_in)  * (20/128)^2
20
21
# equiv to:
22
# N = 128 * sqrt(S_in / S_target)
23
24
25
# the approach this script takes is to create a set of grids at different resolutions, and pick the grid points based on the binned input surface area.
26
27
# all the points from the different resolutions are then triangulated
28
29
# Pipeline parameters (the `snakemake` object is injected by Snakemake at runtime).
target_surfarea = snakemake.params.targetarea  # desired per-vertex surface area
num_bins = snakemake.params.nbins  # number of histogram bins for the input areas

# Geometry of the input surfarea grid; dims/start/end are indexed as (AP, PD).
surfarea_cfg = snakemake.config['in_surfarea']
n_ap, n_pd = surfarea_cfg['dims'][0], surfarea_cfg['dims'][1]
start_ap, start_pd = surfarea_cfg['start'][0], surfarea_cfg['start'][1]
end_ap, end_pd = surfarea_cfg['end'][0], surfarea_cfg['end'][1]

# Reference number of grid points along PD (input dims + 2), e.g. 128 for hipp
# and 32 for dentate; this plays the role of "128" in N = 128 * sqrt(S_in / S_target).
N_pd = n_pd + 2
# Ratio of AP to PD grid points, truncated to an int (e.g. 256 / 128 = 2 for hipp).
aspect_ratio = int((n_ap + 2) / (n_pd + 2))
41
42
# Load the average per-vertex surface-area metric and arrange it on the
# (PD, AP) grid, i.e. shape (n_pd, n_ap).
# NOTE(review): the metric is read from the first data array with intent
# NIFTI_INTENT_NORMAL — presumably how the upstream rule writes it; confirm.
surfarea_gii = nib.load(snakemake.input.surfarea_gii)
arr_surfarea = surfarea_gii.get_arrays_from_intent('NIFTI_INTENT_NORMAL')[0].data.reshape(n_pd, n_ap)

# replace nan with 0.01 (actually, shouldn't be any nans with the latest workflow)
arr_surfarea = np.nan_to_num(arr_surfarea, nan=0.01)

# Bin the areas into discrete regions using a histogram. Each vertex gets the
# index (0 .. nbins-1) of the bin that contains its area, following
# np.histogram's convention: every bin is half-open [lo, hi) except the last,
# which is closed [lo, hi]. The closed last bin matters because the largest
# area is exactly histedges[-1]; a half-open test there would match no bin and
# leave that vertex in bin 0 (the smallest-area, densest bin).
(histval, histedges) = np.histogram(arr_surfarea.flat, bins=num_bins)
n_hist_bins = len(histedges) - 1
# searchsorted(side='right') - 1 yields k with histedges[k] <= a < histedges[k+1];
# clipping folds a == histedges[-1] into the last bin.
binned_area = np.clip(
    np.searchsorted(histedges, arr_surfarea, side='right') - 1,
    0,
    n_hist_bins - 1,
).astype(float)

print(binned_area.shape)
60
61
62
63
# Sample coordinates of the input surfarea grid: AP along x, PD along y.
# The configured start/end are used because the original flat space was offset
# a bit from 0-40, 0-20.
orig_x = np.linspace(start_ap, end_ap, n_ap)
orig_y = np.linspace(start_pd, end_pd, n_pd)
for axis_coords in (orig_x, orig_y):
    print(axis_coords)

# Nearest-neighbour lookup of the bin index at arbitrary points, used to sample
# every multi-resolution grid below. Axis order (PD, AP) matches binned_area's
# (n_pd, n_ap) layout, so query points must be given as (pd, ap).
interpolator = RegularGridInterpolator(
    (orig_y, orig_x),
    binned_area,
    method='nearest',
)
72
73
74
75
76
# Collect candidate vertices from one regular grid per histogram bin. Each
# grid's density is chosen so its vertices have roughly the target area, and
# only the grid points lying in that bin's region are kept.
multires_points = list()

for i in range(len(histedges) - 1):

    # Bin centre: representative vertex area S_in for the vertices in this bin.
    in_surfarea = (histedges[i] + histedges[i + 1]) * 0.5

    # Vertex area scales with grid spacing squared, so the spacing (1/N) that
    # yields the target area S_target satisfies
    #   (1/N)^2 = (S_target / S_in) * (1/N_pd)^2
    # equiv to:
    #   N = N_pd * sqrt(S_in / S_target)
    print(f'input surf area at bin {i}: {in_surfarea}')

    N = int(N_pd * np.sqrt(in_surfarea / target_surfarea))
    #N = int(snakemake.params.scaling_factor * np.power(in_surfarea / target_surfarea,snakemake.params.power_factor))
    print(f'N for gridding is: {N}')
    if N < 2:
        # linspace with fewer than 2 points cannot span [start_pd, end_pd],
        # so this bin's region will receive few or no vertices.
        print(f'WARNING: N={N} for bin {i} is too coarse to cover its region')

    # Regular grid over the flat space: aspect_ratio*N points along AP, N along PD.
    nx, ny = (aspect_ratio * N, N)
    x = np.linspace(start_ap, end_ap, nx)
    y = np.linspace(start_pd, end_pd, ny)
    # meshgrid(y, x) with the default 'xy' indexing returns two (nx, ny) arrays:
    # the first holds PD coordinates, the second AP coordinates.
    pd_grid, ap_grid = np.meshgrid(y, x)

    # Nearest input bin index at every grid point (interpolator takes (PD, AP)).
    binval = interpolator((pd_grid, ap_grid))

    # Keep only the grid points lying in this bin's region; columns are (PD, AP).
    in_bin = binval == i
    multires_points.append(np.column_stack((pd_grid[in_bin], ap_grid[in_bin])))

all_points = np.vstack(multires_points)
print(all_points.shape)
125
126
# Triangulate the combined multi-resolution points; `tri` is also what gets
# written out as the final mesh.
tri = Delaunay(all_points)

# Plot the mesh for QC. Column 0 of all_points is PD and column 1 is AP, so PD
# runs horizontally; the tall figure suits the AP extent (~0-40) being larger
# than the PD extent (~0-20).
fig = plt.figure(figsize=(50, 100))
plt.triplot(all_points[:, 0], all_points[:, 1], tri.simplices)
plt.plot(all_points[:, 0], all_points[:, 1], '.')
plt.savefig(snakemake.output.grid_png)
# The 50x100-inch canvas is large; release it once it has been written.
plt.close(fig)
132
133
134
# Write the vertices (floats) and triangles (vertex indices) as plain-text
# tables so the gifti surface can be created from MATLAB. Note: np.savetxt's
# default delimiter is a space, so these ".csv" files are space-delimited.
vertices_fname = snakemake.output.points_csv
triangles_fname = snakemake.output.triangles_csv

nverts = len(all_points)
print(f'nverts: {nverts}')
nbins = num_bins  # NOTE(review): not used anywhere in the visible script

np.savetxt(vertices_fname, tri.points, fmt='%f')
np.savetxt(triangles_fname, tri.simplices, fmt='%u')
146