# lit_gpt/config.py
import json
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union

import torch
from typing_extensions import Self

import lit_gpt.model
from lit_gpt.utils import find_multiple


@dataclass
class Config:
    name: str = ""
    hf_config: dict = field(default_factory=dict)
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │         │        │                 │
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
    #         MHA                    GQA                   MQA
    #   n_query_groups=4       n_query_groups=2      n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000

    def __post_init__(self):
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        assert self.n_embd % self.n_head == 0
        self.head_size = self.n_embd // self.n_head

        # the vocab size is padded to a multiple of `padding_multiple` for hardware efficiency. compute the closest value
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        else:
            # vocab size shouldn't be larger than padded vocab size
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # compute the number of query groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self._mlp_class == "LLaMAMLP":
                raise ValueError("The config needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        if name not in name_to_config:
            # search through all `config['hf_config']['name']`
            try:
                conf_dict = next(config for config in configs if name == config["hf_config"]["name"])
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        conf_dict = conf_dict.copy()
        if "condense_ratio" in kwargs:  # legacy name
            kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @classmethod
    def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        with open(path, encoding="utf-8") as fp:
            json_kwargs = json.load(fp)
        if "condense_ratio" in json_kwargs:  # legacy name
            json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio")
        if "condense_ratio" in kwargs:  # legacy name
            kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
        if "org" in json_kwargs:  # legacy name
            json_kwargs["hf_config"] = {"name": json_kwargs["name"], "org": json_kwargs.pop("org")}
        if "org" in kwargs:  # legacy name
            kwargs["hf_config"] = {"name": kwargs.get("name", json_kwargs["name"]), "org": kwargs.pop("org")}
        json_kwargs.update(kwargs)
        return cls(**json_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Automatically load `lit_config.json` and, if it doesn't exist, a matching config from `lit_gpt/config.py`."""
        if (config_path := path / "lit_config.json").is_file():
            return cls.from_json(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(f"For {str(path)!r} neither 'lit_config.json' nor a matching config exists.")

    @property
    def mlp_class(self) -> Type:
        # `self._mlp_class` cannot be the type to keep the config json serializable
        return getattr(lit_gpt.model, self._mlp_class)

    @property
    def norm_class(self) -> Type:
        # `self._norm_class` cannot be the type to keep the config json serializable
        if self._norm_class == "RMSNorm":
            from lit_gpt.rmsnorm import RMSNorm

            return RMSNorm
        return getattr(torch.nn, self._norm_class)
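

# Example usage (a minimal sketch; "pythia-70m" is one of the configs registered
# below, and the JSON path is hypothetical):
#
#   config = Config.from_name("pythia-70m")                    # look up a predefined config
#   config = Config.from_name("pythia-70m", block_size=1024)   # override a field
#   config = Config.from_name("pythia-70m", n_query_groups=4)  # GQA: 8 heads share 4 KV groups
#   config = Config.from_json("some/dir/lit_config.json")      # load from a JSON file
#
# `config.mlp_class` and `config.norm_class` resolve the serializable string fields
# `_mlp_class` and `_norm_class` to the actual classes (`torch.nn.LayerNorm` by default).
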
########################
# Stability AI StableLM
########################
configs = [
    # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
    dict(name="stablelm-base-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b")),
    # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
    dict(
        name="stablelm-base-alpha-7b",
        hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"),
        n_head=48,
        n_embd=6144,
        padding_multiple=256,
    ),
    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
    dict(name="stablelm-tuned-alpha-3b", hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"), n_head=32),
    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
    dict(
        name="stablelm-tuned-alpha-7b",
        hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"),
        n_head=48,
        n_embd=6144,
        padding_multiple=256,
    ),
]

####################
# EleutherAI Pythia
####################
pythia = [
    # https://huggingface.co/EleutherAI/pythia-14m/blob/main/config.json
    dict(
        name="pythia-14m",
        hf_config=dict(org="EleutherAI", name="pythia-14m"),
        block_size=512,
        n_layer=6,
        n_embd=128,
        n_head=4,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json
    dict(
        name="pythia-31m",
        hf_config=dict(org="EleutherAI", name="pythia-31m"),
        block_size=1024,
        n_layer=6,
        n_embd=256,
        n_head=8,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
    dict(
        name="pythia-70m",
        hf_config=dict(org="EleutherAI", name="pythia-70m"),
        block_size=2048,
        n_layer=6,
        n_embd=512,
        n_head=8,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
    dict(
        name="pythia-160m",
        hf_config=dict(org="EleutherAI", name="pythia-160m"),
        block_size=2048,
        n_layer=12,
        n_embd=768,
        n_head=12,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
    dict(
        name="pythia-410m",
        hf_config=dict(org="EleutherAI", name="pythia-410m"),
        block_size=2048,
        n_layer=24,
        n_embd=1024,
        n_head=16,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
    dict(
        name="pythia-1b",
        hf_config=dict(org="EleutherAI", name="pythia-1b"),
        block_size=2048,
        n_embd=2048,
        n_head=8,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
    dict(
        name="pythia-1.4b",
        hf_config=dict(org="EleutherAI", name="pythia-1.4b"),
        block_size=2048,
        n_layer=24,
        n_embd=2048,
        n_head=16,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
    dict(
        name="pythia-2.8b",
        hf_config=dict(org="EleutherAI", name="pythia-2.8b"),
        block_size=2048,
        n_layer=32,
        n_embd=2560,
        padding_multiple=128,
    ),
    # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
    dict(
        name="pythia-6.9b",
        hf_config=dict(org="EleutherAI", name="pythia-6.9b"),
        block_size=2048,
        n_layer=32,
        padding_multiple=256,
    ),
    # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
    dict(
        name="pythia-12b",
        hf_config=dict(org="EleutherAI", name="pythia-12b"),
        block_size=2048,
        n_layer=36,
        n_embd=5120,
        n_head=40,
    ),
]
configs.extend(pythia)
for c in pythia:
    # "pythia-14m" and "pythia-31m" don't have deduped versions
    if c["name"] in ("pythia-14m", "pythia-31m"):
        continue
    copy = deepcopy(c)
    copy["name"] = f"{c['name']}-deduped"
    copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped"
    configs.append(copy)


####################################
# togethercomputer RedPajama INCITE
####################################
redpajama_incite = [
    # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
    dict(
        name="RedPajama-INCITE-{}-3B-v1",
        hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-3B-v1"),
        block_size=2048,
        n_layer=32,
        n_embd=2560,
        padding_multiple=256,
        rotary_percentage=1.0,
        parallel_residual=False,
    ),
    # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
    dict(
        name="RedPajama-INCITE-7B-{}",
        hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-7B-{}"),
        block_size=2048,
        n_layer=32,
        padding_multiple=256,
        rotary_percentage=1.0,
        parallel_residual=False,
    ),
    # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
    dict(
        name="RedPajama-INCITE-{}-7B-v0.1",
        hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-7B-v0.1"),
        block_size=2048,
        n_layer=32,
        padding_multiple=256,
        rotary_percentage=1.0,
        parallel_residual=False,
    ),
]
for c in redpajama_incite:
    for kind in ("Base", "Chat", "Instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)
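
# Each `{}` name template above is expanded into one concrete config per variant,
# e.g. "RedPajama-INCITE-{}-3B-v1" yields "RedPajama-INCITE-Base-3B-v1",
# "RedPajama-INCITE-Chat-3B-v1" and "RedPajama-INCITE-Instruct-3B-v1". The same
# expansion pattern is reused for the Falcon, Llama 2, Mistral and TinyLlama
# variants below.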


#################
# TII UAE Falcon
#################
falcon = [
    # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
    dict(
        name="falcon-7b{}",
        hf_config=dict(org="tiiuae", name="falcon-7b{}"),
        block_size=2048,
        vocab_size=65024,
        padded_vocab_size=65024,
        n_layer=32,
        n_head=71,
        n_embd=4544,
        rotary_percentage=1.0,
        n_query_groups=1,
        bias=False,
        # this is not in the HF config, but it is in the original model implementation; it applies only to this config
        shared_attention_norm=True,
    ),
    # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
    dict(
        name="falcon-40b{}",
        hf_config=dict(org="tiiuae", name="falcon-40b{}"),
        block_size=2048,
        vocab_size=65024,
        padded_vocab_size=65024,
        n_layer=60,
        n_head=128,
        n_embd=8192,
        rotary_percentage=1.0,
        n_query_groups=8,
        bias=False,
    ),
]
for c in falcon:
    for kind in ("", "-instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)

# https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
falcon180b = dict(
    name="falcon-180B{}",
    hf_config=dict(org="tiiuae", name="falcon-180B{}"),
    block_size=2048,
    vocab_size=65024,
    padded_vocab_size=65024,
    n_layer=80,
    n_head=232,
    n_embd=14848,
    rotary_percentage=1.0,
    n_query_groups=8,
    bias=False,
)

for kind in ("", "-chat"):
    copy = deepcopy(falcon180b)
    copy["name"] = falcon180b["name"].format(kind)
    copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
    configs.append(copy)


#############################
# OpenLM Research Open LLaMA
#############################
open_LLaMA = [
    # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
    dict(
        name="open_llama_3b",
        hf_config=dict(org="openlm-research", name="open_llama_3b"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=26,
        n_embd=3200,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=8640,
    ),
    # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
    dict(
        name="open_llama_7b",
        hf_config=dict(org="openlm-research", name="open_llama_7b"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
    dict(
        name="open_llama_13b",
        hf_config=dict(org="openlm-research", name="open_llama_13b"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
]
configs.extend(open_LLaMA)


###############
# LMSYS Vicuna
###############
vicuna = [
    # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
    dict(
        name="vicuna-7b-v1.3",
        hf_config=dict(org="lmsys", name="vicuna-7b-v1.3"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
    dict(
        name="vicuna-13b-v1.3",
        hf_config=dict(org="lmsys", name="vicuna-13b-v1.3"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
    dict(
        name="vicuna-33b-v1.3",
        hf_config=dict(org="lmsys", name="vicuna-33b-v1.3"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=60,
        n_head=52,
        n_embd=6656,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=17920,
    ),
    # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
    dict(
        name="vicuna-7b-v1.5",
        hf_config=dict(org="lmsys", name="vicuna-7b-v1.5"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json
    dict(
        name="vicuna-7b-v1.5-16k",
        hf_config=dict(org="lmsys", name="vicuna-7b-v1.5-16k"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        rope_condense_ratio=4,
    ),
    # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json
    dict(
        name="vicuna-13b-v1.5",
        hf_config=dict(org="lmsys", name="vicuna-13b-v1.5"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
    dict(
        name="vicuna-13b-v1.5-16k",
        hf_config=dict(org="lmsys", name="vicuna-13b-v1.5-16k"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
        rope_condense_ratio=4,
    ),
]
configs.extend(vicuna)


#################
# LMSYS LongChat
#################
long_chat = [
    # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
    dict(
        name="longchat-7b-16k",
        hf_config=dict(org="lmsys", name="longchat-7b-16k"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        rope_condense_ratio=8,
    ),
    # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
    dict(
        name="longchat-13b-16k",
        hf_config=dict(org="lmsys", name="longchat-13b-16k"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
        rope_condense_ratio=8,
    ),
]
configs.extend(long_chat)


######################
# NousResearch Hermes
######################
nous_research = [
    # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json
    dict(
        name="Nous-Hermes-llama-2-7b",
        hf_config=dict(org="NousResearch", name="Nous-Hermes-llama-2-7b"),
        padded_vocab_size=32000,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
    dict(
        name="Nous-Hermes-13b",
        hf_config=dict(org="NousResearch", name="Nous-Hermes-13b"),
        block_size=2048,
        vocab_size=32000,
        padded_vocab_size=32001,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b
    dict(
        name="Nous-Hermes-Llama2-13b",
        hf_config=dict(org="NousResearch", name="Nous-Hermes-Llama2-13b"),
        vocab_size=32000,
        padded_vocab_size=32032,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
]
configs.extend(nous_research)


###############
# Meta LLaMA 2
###############
llama_2 = [
    # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
    dict(
        name="Llama-2-7b{}-hf",
        hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
    dict(
        name="Llama-2-13b{}-hf",
        hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
    dict(
        name="Llama-2-70b{}-hf",
        hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=28672,
    ),
]
for c in llama_2:
    for kind in ("", "-chat"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)


##########################
# Stability AI FreeWilly2
##########################
freewilly_2 = [
    # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
    dict(
        name="FreeWilly2",
        hf_config=dict(org="stabilityai", name="FreeWilly2"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=28672,
    )
]
configs.extend(freewilly_2)


##################
# Meta Code Llama
##################
code_llama = [
    # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
    dict(
        name="CodeLlama-7b-hf",
        hf_config=dict(org="codellama", name="CodeLlama-7b-hf"),
        block_size=16384,
        vocab_size=32016,
        padding_multiple=16,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
    dict(
        name="CodeLlama-13b-hf",
        hf_config=dict(org="codellama", name="CodeLlama-13b-hf"),
        block_size=16384,
        vocab_size=32016,
        padding_multiple=16,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
    dict(
        name="CodeLlama-34b-hf",
        hf_config=dict(org="codellama", name="CodeLlama-34b-hf"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=48,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=22016,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
    dict(
        name="CodeLlama-7b-Python-hf",
        hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
    dict(
        name="CodeLlama-13b-Python-hf",
        hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
    dict(
        name="CodeLlama-34b-Python-hf",
        hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=48,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=22016,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json
    dict(
        name="CodeLlama-7b-Instruct-hf",
        hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"),
        block_size=16384,
        vocab_size=32016,
        padding_multiple=16,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
    dict(
        name="CodeLlama-13b-Instruct-hf",
        hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"),
        block_size=2048,
        vocab_size=32016,
        padding_multiple=16,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
        rope_base=1000000,
    ),
    # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
    dict(
        name="CodeLlama-34b-Instruct-hf",
        hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"),
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=48,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=22016,
        rope_base=1000000,
    ),
]
configs.extend(code_llama)


########################
# garage-bAInd Platypus
########################
platypus = [
    # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
    dict(
        name="Platypus-30B",
        hf_config=dict(org="garage-bAInd", name="Platypus-30B"),
        block_size=2048,
        padded_vocab_size=32000,
        n_layer=60,
        n_head=52,
        n_embd=6656,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-06,
        _mlp_class="LLaMAMLP",
        intermediate_size=17920,
    ),
    # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
    dict(
        name="Platypus2-7B",
        hf_config=dict(org="garage-bAInd", name="Platypus2-7B"),
        padded_vocab_size=32000,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
    dict(
        name="Platypus2-13B",
        hf_config=dict(org="garage-bAInd", name="Platypus2-13B"),
        padded_vocab_size=32000,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
    dict(
        name="Platypus2-70B",
        hf_config=dict(org="garage-bAInd", name="Platypus2-70B"),
        padded_vocab_size=32000,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=28672,
    ),
    # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
    dict(
        name="Camel-Platypus2-13B",
        hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"),
        padded_vocab_size=32000,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
    dict(
        name="Camel-Platypus2-70B",
        hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"),
        padded_vocab_size=32000,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=28672,
    ),
    # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
    dict(
        name="Stable-Platypus2-13B",
        hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"),
        padded_vocab_size=32000,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
    dict(
        name="Platypus2-70B-instruct",
        hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"),
        padded_vocab_size=32000,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=28672,
    ),
]
configs.extend(platypus)


##########################
# Stability AI StableCode
##########################
stablecode = [
    # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
    dict(
        name="stablecode-completion-alpha-3b",
        hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b"),
        block_size=16384,
        vocab_size=49152,
        n_layer=32,
        n_embd=2560,
    ),
    # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
    dict(
        name="stablecode-completion-alpha-3b-4k",
        hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k"),
        vocab_size=49152,
        n_layer=32,
        n_embd=2560,
    ),
    # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
    dict(
        name="stablecode-instruct-alpha-3b",
        hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"),
        vocab_size=49152,
        n_layer=32,
        n_embd=2560,
    ),
]
configs.extend(stablecode)


##################################
# togethercomputer LLaMA-2-7B-32K
##################################
together_llama2_32k = [
    # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
    dict(
        name="LLaMA-2-7B-32K",
        hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"),
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        rope_condense_ratio=8,
    )
]
configs.extend(together_llama2_32k)


################
# Microsoft Phi
################
phi = [
    # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
    dict(
        name="phi-1_5",
        hf_config=dict(org="microsoft", name="phi-1_5"),
        vocab_size=50257,
        padded_vocab_size=51200,
        block_size=2048,
        n_embd=2048,
        n_layer=24,
        rotary_percentage=0.5,  # 32 / (n_embd / n_head) = 32 / 64
        shared_attention_norm=True,
        lm_head_bias=True,
        gelu_approximate="tanh",
    )
]
configs.extend(phi)


#############
# Mistral AI
#############
mistral = [
    # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
    dict(
        name="Mistral-7B-{}v0.1",
        hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"),
        padded_vocab_size=32000,
        block_size=4096,  # should be 32768 but sliding window attention is not implemented
        n_layer=32,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-05,
        _mlp_class="LLaMAMLP",
        intermediate_size=14336,
    )
]
for c in mistral:
    for kind in ("", "Instruct-"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)


############
# TinyLlama
############
tiny_llama = [
    dict(
        name="tiny-llama-1.1b{}",
        hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",  # original TinyLlama uses FusedRMSNorm
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
]
for c in tiny_llama:
    for kind, hf_postfix in (("", "-intermediate-step-955k-token-2T"), ("chat", "-Chat-v0.6")):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(hf_postfix)
        configs.append(copy)


name_to_config = {config["name"]: config for config in configs}
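
# Example lookup (a minimal sketch; the checkpoint path is hypothetical):
#
#   from pathlib import Path
#
#   config = Config.from_checkpoint(Path("checkpoints/EleutherAI/pythia-70m"))
#
# `from_checkpoint` prefers an existing `lit_config.json` in the directory and
# otherwise falls back to `name_to_config`, keyed by the directory name.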