# bionemo-webdatamodule

To install, execute the following:
```bash
pip install -e .
```

To run unit tests, execute:
```bash
pytest -v .
```

## WebDataModule

```python
class WebDataModule(L.LightningDataModule)
```

A LightningDataModule for using webdataset tar files.

`WebDataModule` is a `LightningDataModule` that uses webdataset tar files to set up PyTorch
datasets and dataloaders. This data module takes as input a dictionary mapping each `Split` to a
tar-file directory, along with various webdataset config settings. In its setup() function, it creates the
webdataset object, chaining up the input `pipeline_wds` workflow. In its train/val/test_dataloader(),
it creates the WebLoader object, chaining up the `pipeline_prebatch_wld` workflow.
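
As a rough mental model, each entry in `pipeline_wds` or `pipeline_prebatch_wld` is a composable that maps one iterator of samples to another, and a list of composables is applied by nested function calls. A minimal sketch of that convention (an assumption for illustration, not the actual implementation):

```python
# Minimal sketch (assumed behavior, not the actual implementation): a
# "composable" maps an iterator of samples to a new iterator, so applying
# [f1, f2, f3] to a dataset amounts to computing f3(f2(f1(dataset))).
def apply_pipeline(source, composables):
    for f in composables:
        source = f(source)  # each step wraps the previous iterator
    return source
```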

**Examples**:

  --------
  1. create the data module with an input directory of webdataset tar files.
  Depending on which of the downstream Lightning.Trainer methods are called,
  e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or
  `Trainer.predict()`, only a subset of the train, val and test splits need to
  be specified in the various input options to the data module:

  - `Trainer.fit()` requires the `train` and `val` splits
  - `Trainer.validate()` requires the `val` split
  - `Trainer.test()` requires the `test` split
  - `Trainer.predict()` requires the `test` split

  Here is an example of constructing the data module for `Trainer.fit()`:
```python
>>> import random
>>>
>>> import torch
>>> import webdataset as wds
>>>
>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
>>>
>>> tar_file_prefix = "shards"
>>>
>>> dirs_of_tar_files = {
>>>     Split.train: "/path/to/train/split/tars",
>>>     Split.val: "/path/to/val/split/tars",
>>> }
>>>
>>> n_samples = {
>>>     Split.train: 1000,
>>>     Split.val: 100,
>>> }
>>>
>>> # this is the string used to retrieve the corresponding data object from the
>>> # webdataset file (see
>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
>>> # for details)
>>> suffix_keys_wds = "tensor.pyd"
>>>
>>> seed_rng_shfl = 27193781
>>>
>>> # Specify the routines to process the samples in the WebDataset object.
>>> # Each routine is an Iterable of composables that are chained
>>> # together by nested function calls. The following is equivalent to
>>> # defining an overall generator `shuffle(untuple(...))`, which
>>> # untuples the samples and shuffles them. See webdataset's documentation
>>> # for details.
>>> # NOTE: the `untuple` is almost always necessary due to webdataset's
>>> # file parsing rule.
>>>
>>> untuple = lambda source: (sample for (sample,) in source)
>>>
>>> from webdataset import batched, shuffle
>>> pipeline_wds = {
>>>     Split.train: [untuple, shuffle(n_samples[Split.train],
>>>                                    rng=random.Random(seed_rng_shfl))],
>>>     Split.val: untuple,
>>> }
>>>
>>> # Similarly the user can optionally define the processing routines on the
>>> # WebLoader (the dataloader of webdataset).
>>> # NOTE: these routines by default take unbatched samples as input, so the
>>> # user can customize their batching routines here
>>>
>>> local_batch_size = 16  # example batch size
>>> batch = batched(local_batch_size,
>>>                 collation_fn=lambda list_samples: torch.vstack(list_samples))
>>> pipeline_prebatch_wld = {
>>>     Split.train: [shuffle(n_samples[Split.train],
>>>                           rng=random.Random(seed_rng_shfl)), batch],
>>>     Split.val: batch,
>>>     Split.test: batch,
>>> }
>>>
>>> # the user can optionally specify the kwargs for WebDataset and
>>> # WebLoader
>>>
>>> kwargs_wds = {
>>>     split: {"shardshuffle": split == Split.train,
>>>             "nodesplitter": wds.split_by_node,
>>>             "seed": seed_rng_shfl}
>>>     for split in Split
>>> }
>>>
>>> kwargs_wld = {
>>>     split: {"num_workers": 2} for split in Split
>>> }
>>>
>>> invoke_wds = {
>>>     split: [("with_epoch", {"nbatches": 5})] for split in Split
>>> }
>>>
>>> invoke_wld = {
>>>     split: [("with_epoch", {"nbatches": 5})] for split in Split
>>> }
>>>
>>> # construct the data module
>>> data_module = WebDataModule(suffix_keys_wds,
>>>                             dirs_of_tar_files,
>>>                             prefix_tars_wds=tar_file_prefix,
>>>                             pipeline_wds=pipeline_wds,
>>>                             pipeline_prebatch_wld=pipeline_prebatch_wld,
>>>                             kwargs_wds=kwargs_wds,
>>>                             kwargs_wld=kwargs_wld,
>>>                             invoke_wds=invoke_wds,
>>>                             invoke_wld=invoke_wld,
>>>                             )
```

<a id="datamodule.WebDataModule.__init__"></a>

#### \_\_init\_\_

```python
def __init__(
    suffix_keys_wds: Union[str, Iterable[str]],
    dirs_tars_wds: Dict[Split, str],
    prefix_tars_wds: str = "wdshards",
    pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
                                             Iterable[Any]]]] = None,
    pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
                                                      Iterable[Any]]]] = None,
    kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
    kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None,
    invoke_wds: Optional[Dict[Split, List[Tuple[str, Dict[str, Any]]]]] = None,
    invoke_wld: Optional[Dict[Split, List[Tuple[str, Dict[str, Any]]]]] = None)
```

Constructor.

**Arguments**:

- `suffix_keys_wds` - a set of keys, each
  corresponding to a data object in the webdataset tar file
  dictionary. The data objects of these keys will be extracted and
  tupled for each sample in the tar files
- `dirs_tars_wds` - input dictionary: Split -> tar file
  directory that contains the webdataset tar files for each split
  Kwargs:
- `prefix_tars_wds` - name prefix of the input webdataset tar
  files. The input tar files are globbed by
  "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar"
- `pipeline_wds` - a dictionary of webdataset composables, i.e.,
  functors that map an iterator to another iterator that
  transforms the data samples yielded from the dataset object, for
  different splits, or an iterable of such functors. For example,
  this can be used to transform the sample in the worker before
  sending it to the main process of the dataloader
- `pipeline_prebatch_wld` - a dictionary
  of webloader composables, i.e., functors that map an iterator to
  another iterator that transforms the data samples yielded from the
  WebLoader object, for different splits, or an iterable of such
  functors. For example, this can be used for batching the samples.
  NOTE: this is applied before the batch is yielded from the WebLoader
- `kwargs_wds` - kwargs for the WebDataset.__init__() of each split
- `kwargs_wld` - kwargs for the WebLoader.__init__() of each split, e.g., num_workers
- `invoke_wds` - a dictionary of WebDataset methods to be called upon WebDataset
  construction. These methods must return the WebDataset object itself. Examples
  are .with_length() and .with_epoch(). These methods will be applied towards
  the end of returning the WebDataset object, i.e., after the `pipeline_wds`
  have been applied. The inner list of tuples each has its first element as the
  method name and the second element as the corresponding method's kwargs.
- `invoke_wld` - a dictionary of WebLoader methods to be called upon WebLoader
  construction. These methods must return the WebLoader object itself. Examples
  are .with_length() and .with_epoch(). These methods will be applied towards
  the end of returning the WebLoader object, i.e., after the `pipeline_prebatch_wld`
  have been applied. The inner list of tuples each has its first element as the
  method name and the second element as the corresponding method's kwargs.
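
The `("method_name", kwargs)` tuples in `invoke_wds`/`invoke_wld` suggest application by attribute lookup; a minimal sketch of that convention (an assumption for illustration, not the actual implementation):

```python
# Sketch (assumption, not the actual implementation) of how the
# ("method_name", kwargs) tuples could be applied to a WebDataset/WebLoader:
def apply_invocations(obj, invocations):
    for method_name, method_kwargs in invocations:
        # each invoked method must return the object itself, e.g. with_epoch()
        obj = getattr(obj, method_name)(**method_kwargs)
    return obj
```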

<a id="datamodule.WebDataModule.prepare_data"></a>

#### prepare\_data

```python
def prepare_data() -> None
```

This is called only by the main process by the Lightning workflow.

Do not rely on this data module object's state update here, as there is no
way to communicate the state update to other subprocesses. This method is a **no-op**.

<a id="datamodule.WebDataModule.setup"></a>

#### setup

```python
def setup(stage: str) -> None
```

This is called on all Lightning-managed nodes in a multi-node training session.

**Arguments**:

- `stage` - "fit", "test" or "predict"
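
Outside of a `Trainer` run, the module can also be exercised manually. A small sketch, reusing the hypothetical `data_module` constructed in the example above:

```python
>>> # Manual use outside of a Trainer run (illustrative; `data_module` is the
>>> # instance constructed in the example above).
>>> data_module.prepare_data()   # no-op for WebDataModule
>>> data_module.setup("fit")     # builds the train and val WebDataset objects
```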

<a id="datamodule.WebDataModule.train_dataloader"></a>

#### train\_dataloader

```python
def train_dataloader() -> wds.WebLoader
```

The WebLoader for the training data.

<a id="datamodule.WebDataModule.val_dataloader"></a>

#### val\_dataloader

```python
def val_dataloader() -> wds.WebLoader
```

The WebLoader for the validation data.

<a id="datamodule.WebDataModule.test_dataloader"></a>

#### test\_dataloader

```python
def test_dataloader() -> wds.WebLoader
```

The WebLoader for the test data.

<a id="datamodule.WebDataModule.predict_dataloader"></a>

#### predict\_dataloader

```python
def predict_dataloader() -> wds.WebLoader
```

Alias for :func:`test_dataloader`.
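
Once `setup()` has run, the returned loaders iterate like any PyTorch dataloader. A minimal sketch, again assuming the hypothetical `data_module` from above:

```python
>>> # Minimal iteration sketch (illustrative `data_module` from above)
>>> train_loader = data_module.train_dataloader()
>>> for batch in train_loader:
>>>     pass  # e.g., run the forward pass over the batch
```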

<a id="datamodule.PickledDataWDS"></a>

## PickledDataWDS Objects

```python
class PickledDataWDS(WebDataModule)
```

A LightningDataModule to process pickled data into webdataset tar files.

`PickledDataWDS` is a LightningDataModule to process pickled data into webdataset tar files
and set up the dataset and dataloader. It inherits the webdataset setup from its parent module
`WebDataModule`. This data module takes a directory of pickled data files, data filename
prefixes for the train/val/test splits, and data filename suffixes, and prepares webdataset tar files
by globbing the specific pickled data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}`
and outputting them to webdataset tar files with the dict structure shown below.
NOTE: this assumes only one pickled file is processed for each sample. In
its setup() function, it creates the webdataset object, chaining up the input
`pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the
WebLoader object, chaining up the `pipeline_prebatch_wld` workflow.

```
    {"__key__" : name.replace(".", "-"),
     suffix_pickles : pickle.dumps(data) }
```
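
For concreteness, here is a sketch of how one such sample entry could be built from a pickled file (hypothetical names taken from the example below; the actual work is done internally by `pickles_to_tars`):

```python
import pickle
from pathlib import Path

# Illustrative construction of the per-sample dict above: load one pickled
# file and build the webdataset sample entry for it.
name = "sample1"                      # filename prefix of one data sample
suffix_pickles = "mydata.pt"          # filename suffix, doubles as the wds key
path = Path("/path/to/my/pickles/dir") / f"{name}.{suffix_pickles}"
with open(path, "rb") as f:
    data = pickle.load(f)
sample = {
    "__key__": name.replace(".", "-"),  # "." is reserved as the key/suffix separator
    suffix_pickles: pickle.dumps(data),
}
```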

**Examples**:

  --------
  1. create the data module with a directory of pickle files and the file name
  prefixes thereof for different splits, to be used by `Lightning.Trainer.fit()`

```python
>>> from bionemo.webdatamodule.datamodule import Split, PickledDataWDS

>>> dir_pickles = "/path/to/my/pickles/dir"

>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
>>> # validation dataset

>>> suffix_pickles = "mydata.pt"

>>> names_subset = {
>>>     Split.train: ["sample1", "sample2"],
>>>     Split.val: ["sample4", "sample5"],
>>> }

>>> # the following setting will attempt to create at least 5 tar files in
>>> # `/path/to/output/tars/dir-{train,val,test}/myshards-00000{0-4}.tar`

>>> n_tars_wds = 5
>>> prefix_tars_wds = "myshards"
>>> output_dir_tar_files = {
>>>     Split.train: "/path/to/output/tars/dir-train",
>>>     Split.val: "/path/to/output/tars/dir-val",
>>>     Split.test: "/path/to/output/tars/dir-test",
>>> }

>>> # the user can optionally customize the data processing routines and kwargs used
>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)

>>> pipeline_wds = { Split.train: ... }

>>> pipeline_prebatch_wld = { Split.train: ... }

>>> kwargs_wds = { Split.train: ..., Split.val: ... }

>>> kwargs_wld = { Split.train: ..., Split.val: ... }

>>> invoke_wds = { Split.train: ..., Split.val: ... }

>>> invoke_wld = { Split.train: ..., Split.val: ... }

>>> # create the data module
>>> data_module = PickledDataWDS(
>>>     dir_pickles,
>>>     names_subset,
>>>     suffix_pickles, # `WebDataModule` args
>>>     output_dir_tar_files, # `WebDataModule` args
>>>     n_tars_wds=n_tars_wds,
>>>     prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
>>>     pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
>>>     pipeline_prebatch_wld=pipeline_prebatch_wld, # `WebDataModule` kwargs
>>>     kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
>>>     kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
>>>     invoke_wds=invoke_wds, # `WebDataModule` kwargs
>>>     invoke_wld=invoke_wld, # `WebDataModule` kwargs
>>> )
```

<a id="datamodule.PickledDataWDS.__init__"></a>

#### \_\_init\_\_

```python
def __init__(dir_pickles: str,
             names_subset: Dict[Split, List[str]],
             *args,
             n_tars_wds: Optional[int] = None,
             **kwargs) -> None
```

Constructor.

**Arguments**:

- `dir_pickles` - input directory of pickled data files
- `names_subset` - a dictionary mapping each split to the list of
  filename prefixes of the data samples to be loaded in the dataset
  and dataloader for that split
- `*args` - arguments passed to the parent WebDataModule
- `n_tars_wds` - attempt to create at least this number of
  webdataset shards
- `**kwargs` - keyword arguments passed to the parent WebDataModule

<a id="datamodule.PickledDataWDS.prepare_data"></a>

#### prepare\_data

```python
def prepare_data() -> None
```

This is called only by the main process by the Lightning workflow.

Do not rely on this data module object's state update here, as there is no
way to communicate the state update to other subprocesses. The nested
`pickles_to_tars` function goes through the data name prefixes in the
different splits, reads the corresponding pickled files, and outputs
webdataset tar archives with the dict structure: {"__key__" :
name.replace(".", "-"), suffix_pickles : pickle.dumps(data) }.
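
Conceptually, this tar-creation step resembles the following sketch using webdataset's `ShardWriter` (an assumption for illustration with hypothetical names; see `pickles_to_tars` for the real logic):

```python
import pickle
from pathlib import Path

import webdataset as wds

# Conceptual sketch (assumption, not the actual `pickles_to_tars` code):
# write one webdataset sample per pickled file into a set of tar shards.
def pickles_to_tars_sketch(dir_pickles, names, suffix_pickles, dir_tars,
                           prefix_tars="wdshards", n_shards=5):
    pattern = f"{dir_tars}/{prefix_tars}-%06d.tar"
    maxcount = max(1, len(names) // n_shards)  # samples per shard
    with wds.ShardWriter(pattern, maxcount=maxcount) as sink:
        for name in names:
            with open(Path(dir_pickles) / f"{name}.{suffix_pickles}", "rb") as f:
                data = pickle.load(f)
            sink.write({
                "__key__": name.replace(".", "-"),
                suffix_pickles: pickle.dumps(data),
            })
```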