[7554f1]: / myenv / Lib / site-packages / threadpoolctl.py

Download this file

1080 lines (870 with data), 41.6 kB

   1
   2
   3
   4
   5
   6
   7
   8
   9
  10
  11
  12
  13
  14
  15
  16
  17
  18
  19
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
"""threadpoolctl
This module provides utilities to introspect native libraries that rely on
thread pools (notably BLAS and OpenMP implementations) and dynamically set the
maximal number of threads they can use.
"""
# License: BSD 3-Clause
# The code to introspect dynamically loaded libraries on POSIX systems is
# adapted from code by Intel developer @anton-malakhov available at
# https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation)
# and also published under the BSD 3-Clause license
import os
import re
import sys
import ctypes
import textwrap
from typing import final
import warnings
from ctypes.util import find_library
from abc import ABC, abstractmethod
from functools import lru_cache
from contextlib import ContextDecorator
# Version string of this module.
__version__ = "3.2.0"
# Public names exported by a star-import of this module.
__all__ = [
    "threadpool_limits",
    "threadpool_info",
    "ThreadpoolController",
    "LibController",
    "register",
]
# One can get runtime errors or even segfaults due to multiple OpenMP libraries
# loaded simultaneously which can happen easily in Python when importing and
# using compiled extensions built with different compilers and therefore
# different OpenMP runtimes in the same program. In particular libiomp (used by
# Intel ICC) and libomp used by clang/llvm tend to crash. This can happen for
# instance when calling BLAS inside a prange. Setting the following environment
# variable allows multiple OpenMP libraries to be loaded. It should not degrade
# performance since we manually take care of potential over-subscription
# performance issues, in sections of the code where nested OpenMP loops can
# happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily
# disable it while under the scope of the outer OpenMP parallel section.
# `setdefault` leaves the variable untouched when the user already set it.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")
# Structure to cast the info on dynamically loaded library. See
# https://linux.die.net/man/3/dl_iterate_phdr for more details.
# `sys.maxsize > 2**32` holds on 64-bit interpreters, so the integer widths
# below follow the pointer size of the running Python.
_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32
_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16
class _dl_phdr_info(ctypes.Structure):
    """ctypes layout of the entries handed to a `dl_iterate_phdr` callback."""

    _fields_ = [
        # Base address of object
        ("dlpi_addr", _SYSTEM_UINT),
        # path to the library
        ("dlpi_name", ctypes.c_char_p),
        # pointer on dlpi_headers
        ("dlpi_phdr", ctypes.c_void_p),
        # number of elements in dlpi_phdr
        ("dlpi_phnum", _SYSTEM_UINT_HALF),
    ]
# The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows;
# fall back to the default ctypes loading mode when `os` does not expose it.
_RTLD_NOLOAD = getattr(os, "RTLD_NOLOAD", ctypes.DEFAULT_MODE)
class LibController(ABC):
    """Abstract base class for the individual library controllers

    A library controller must expose the following class attributes:

    - user_api : str
        Usually the name of the library or generic specification the library
        implements, e.g. "blas" is a specification with different implementations.
    - internal_api : str
        Usually the name of the library or concrete implementation of some
        specification, e.g. "openblas" is an implementation of the "blas"
        specification.
    - filename_prefixes : tuple
        Possible prefixes of the shared library's filename that allow to
        identify the library. e.g. "libopenblas" for libopenblas.so.

    and implement the following methods: `get_num_threads`, `set_num_threads` and
    `get_version`.

    Threadpoolctl loops through all the loaded shared libraries and tries to match
    the filename of each library with the `filename_prefixes`. If a match is found, a
    controller is instantiated and a handler to the library is stored in the `dynlib`
    attribute as a `ctypes.CDLL` object. It can be used to access the necessary symbols
    of the shared library to implement the above methods.

    The following information will be exposed in the info dictionary:

    - user_api : standardized API, if any, or a copy of internal_api.
    - internal_api : implementation-specific API.
    - num_threads : the current thread limit.
    - prefix : prefix of the shared library's filename.
    - filepath : path to the loaded shared library.
    - version : version of the library (if available).

    In addition, each library controller may expose internal API specific entries. They
    must be set as attributes in the `set_additional_attributes` method.
    """

    @final
    def __init__(self, *, filepath=None, prefix=None):
        """Open a handle on `filepath` and collect the exposed attributes.

        This is not meant to be overridden by subclasses.
        """
        # The assignment order below is also the key order of `info()`.
        self.prefix = prefix
        self.filepath = filepath
        self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
        self.version = self.get_version()
        self.set_additional_attributes()

    @final
    def info(self):
        """Return relevant info wrapped in a dict

        This is not meant to be overridden by subclasses.
        """
        summary = {
            "user_api": self.user_api,
            "internal_api": self.internal_api,
            "num_threads": self.num_threads,
        }
        # Instance attributes come after (and may override) the three entries above.
        summary.update(vars(self))
        # The ctypes handle is an implementation detail, not information.
        del summary["dynlib"]
        return summary

    def set_additional_attributes(self):
        """Set additional attributes meant to be exposed in the info dict"""

    @property
    def num_threads(self):
        """Exposes the current thread limit as a dynamic property

        This is not meant to be used or overridden by subclasses.
        """
        return self.get_num_threads()

    @abstractmethod
    def get_num_threads(self):
        """Return the maximum number of threads available to use"""

    @abstractmethod
    def set_num_threads(self, num_threads):
        """Set the maximum number of threads to use"""

    @abstractmethod
    def get_version(self):
        """Return the version of the shared library"""
class OpenBLASController(LibController):
    """Controller class for OpenBLAS"""

    user_api = "blas"
    internal_api = "openblas"
    filename_prefixes = ("libopenblas", "libblas")
    check_symbols = ("openblas_get_num_threads", "openblas_get_num_threads64_")

    def set_additional_attributes(self):
        """Expose the threading layer and detected architecture in the info dict."""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the current OpenBLAS thread limit (None if no symbol is found)."""
        # Symbols differ when built for 64bit integers in Fortran
        fallback = getattr(self.dynlib, "openblas_get_num_threads64_", lambda: None)
        getter = getattr(self.dynlib, "openblas_get_num_threads", fallback)
        return getter()

    def set_num_threads(self, num_threads):
        """Set the OpenBLAS thread limit (no-op if no symbol is found)."""
        # Symbols differ when built for 64bit integers in Fortran
        fallback = getattr(
            self.dynlib, "openblas_set_num_threads64_", lambda _num_threads: None
        )
        setter = getattr(self.dynlib, "openblas_set_num_threads", fallback)
        return setter(num_threads)

    def get_version(self):
        """Return the OpenBLAS version parsed from its config string, or None."""
        # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
        # did not expose its version before that.
        fallback = getattr(self.dynlib, "openblas_get_config64_", None)
        get_config = getattr(self.dynlib, "openblas_get_config", fallback)
        if get_config is None:
            return None

        get_config.restype = ctypes.c_char_p
        tokens = get_config().split()
        if tokens[0] != b"OpenBLAS":
            return None
        return tokens[1].decode("utf-8")

    def _get_threading_layer(self):
        """Return the threading layer of OpenBLAS"""
        fallback = getattr(self.dynlib, "openblas_get_parallel64_", None)
        get_parallel = getattr(self.dynlib, "openblas_get_parallel", fallback)
        if get_parallel is None:
            return "unknown"

        # 2 -> OpenMP, 1 -> pthreads, anything else is reported as disabled.
        layer_names = {2: "openmp", 1: "pthreads"}
        return layer_names.get(get_parallel(), "disabled")

    def _get_architecture(self):
        """Return the architecture detected by OpenBLAS"""
        fallback = getattr(self.dynlib, "openblas_get_corename64_", None)
        get_corename = getattr(self.dynlib, "openblas_get_corename", fallback)
        if get_corename is None:
            return None

        get_corename.restype = ctypes.c_char_p
        return get_corename().decode("utf-8")
class BLISController(LibController):
    """Controller class for BLIS"""

    user_api = "blas"
    internal_api = "blis"
    filename_prefixes = ("libblis", "libblas")
    check_symbols = ("bli_thread_get_num_threads",)

    def set_additional_attributes(self):
        """Expose the threading layer and detected architecture in the info dict."""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the current BLIS thread limit (None if the symbol is missing)."""
        get_func = getattr(self.dynlib, "bli_thread_get_num_threads", lambda: None)
        num_threads = get_func()
        # by default BLIS is single-threaded and get_num_threads
        # returns -1. We map it to 1 for consistency with other libraries.
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        """Set the BLIS thread limit (no-op if the symbol is missing)."""
        set_func = getattr(
            self.dynlib, "bli_thread_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        """Return the BLIS version string, or None if it is not exposed."""
        get_version_ = getattr(self.dynlib, "bli_info_get_version_str", None)
        if get_version_ is None:
            return None

        get_version_.restype = ctypes.c_char_p
        return get_version_().decode("utf-8")

    def _get_threading_layer(self):
        """Return the threading layer of BLIS

        A missing query symbol is treated as "that layer is not enabled" rather
        than raising `AttributeError`, consistently with the other accessors of
        this class which all fall back gracefully on missing symbols.
        """
        openmp_enabled = getattr(
            self.dynlib, "bli_info_get_enable_openmp", lambda: False
        )
        if openmp_enabled():
            return "openmp"

        pthreads_enabled = getattr(
            self.dynlib, "bli_info_get_enable_pthreads", lambda: False
        )
        if pthreads_enabled():
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by BLIS"""
        bli_arch_query_id = getattr(self.dynlib, "bli_arch_query_id", None)
        bli_arch_string = getattr(self.dynlib, "bli_arch_string", None)
        if bli_arch_query_id is None or bli_arch_string is None:
            return None

        # the true restype should be BLIS' arch_t (enum) but int should work
        # for us:
        bli_arch_query_id.restype = ctypes.c_int
        bli_arch_string.restype = ctypes.c_char_p
        return bli_arch_string(bli_arch_query_id()).decode("utf-8")
class MKLController(LibController):
    """Controller class for MKL"""

    user_api = "blas"
    internal_api = "mkl"
    filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas")
    check_symbols = ("MKL_Get_Max_Threads",)

    def set_additional_attributes(self):
        """Expose the MKL threading layer in the info dict."""
        self.threading_layer = self._get_threading_layer()

    def get_num_threads(self):
        """Return the current MKL thread limit (None if the symbol is missing)."""
        getter = getattr(self.dynlib, "MKL_Get_Max_Threads", lambda: None)
        return getter()

    def set_num_threads(self, num_threads):
        """Set the MKL thread limit (no-op if the symbol is missing)."""
        setter = getattr(
            self.dynlib, "MKL_Set_Num_Threads", lambda _num_threads: None
        )
        return setter(num_threads)

    def get_version(self):
        """Return the MKL version, or None if the library does not expose it."""
        if not hasattr(self.dynlib, "MKL_Get_Version_String"):
            return None

        buffer = ctypes.create_string_buffer(200)
        self.dynlib.MKL_Get_Version_String(buffer, 200)

        version = buffer.value.decode("utf-8")
        # Keep only the token following "Version " when the pattern is present;
        # otherwise the whole string is returned.
        match = re.search(r"Version ([^ ]+) ", version)
        if match is not None:
            version = match.group(1)
        return version.strip()

    def _get_threading_layer(self):
        """Return the threading layer of MKL"""
        layer_names = {
            0: "intel",
            1: "sequential",
            2: "pgi",
            3: "gnu",
            4: "tbb",
            -1: "not specified",
        }
        # The function mkl_set_threading_layer returns the current threading
        # layer. Calling it with an invalid threading layer allows us to safely
        # get the threading layer
        set_threading_layer = getattr(
            self.dynlib, "MKL_Set_Threading_Layer", lambda layer: -1
        )
        current_layer = set_threading_layer(-1)
        return layer_names[current_layer]
class OpenMPController(LibController):
    """Controller class for OpenMP"""

    user_api = "openmp"
    internal_api = "openmp"
    filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")

    def get_num_threads(self):
        """Return the OpenMP max thread count (None if the symbol is missing)."""
        return getattr(self.dynlib, "omp_get_max_threads", lambda: None)()

    def set_num_threads(self, num_threads):
        """Set the OpenMP thread count (no-op if the symbol is missing)."""
        setter = getattr(
            self.dynlib, "omp_set_num_threads", lambda _num_threads: None
        )
        return setter(num_threads)

    def get_version(self):
        """Always None for OpenMP runtimes."""
        # There is no way to get the version number programmatically in OpenMP.
        return None
# Controllers for the libraries that we'll look for in the loaded libraries.
# Third party libraries can register their own controllers.
_ALL_CONTROLLERS = [
    OpenBLASController,
    BLISController,
    MKLController,
    OpenMPController,
]

# Helpers for the doc and test names (deduplicated through a set where several
# controllers may share the same value).
_ALL_USER_APIS = list({lib.user_api for lib in _ALL_CONTROLLERS})
_ALL_INTERNAL_APIS = [lib.internal_api for lib in _ALL_CONTROLLERS]
_ALL_PREFIXES = list(
    {prefix for lib in _ALL_CONTROLLERS for prefix in lib.filename_prefixes}
)
_ALL_BLAS_LIBRARIES = [
    lib.internal_api for lib in _ALL_CONTROLLERS if lib.user_api == "blas"
]
_ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes
def register(controller):
    """Register a new controller

    Parameters
    ----------
    controller : class
        Controller class exposing the `user_api`, `internal_api` and
        `filename_prefixes` class attributes.

    Registering is idempotent: a controller that is already registered is
    ignored, and user APIs, internal APIs and filename prefixes that are already
    known are not added a second time. Duplicated entries would otherwise show
    up in the helper lists used for parameter validation and messages.
    """
    if controller in _ALL_CONTROLLERS:
        return

    _ALL_CONTROLLERS.append(controller)

    if controller.user_api not in _ALL_USER_APIS:
        _ALL_USER_APIS.append(controller.user_api)

    if controller.internal_api not in _ALL_INTERNAL_APIS:
        _ALL_INTERNAL_APIS.append(controller.internal_api)

    for prefix in controller.filename_prefixes:
        if prefix not in _ALL_PREFIXES:
            _ALL_PREFIXES.append(prefix)
def _format_docstring(*args, **kwargs):
    """Return a decorator that runs `str.format` on the decorated object's docstring.

    `args` and `kwargs` are forwarded to `format`. Objects without a docstring
    are returned untouched.
    """

    def decorator(obj):
        template = obj.__doc__
        if template is None:
            return obj
        obj.__doc__ = template.format(*args, **kwargs)
        return obj

    return decorator
@lru_cache(maxsize=10000)
def _realpath(filepath):
    """Memoized `os.path.realpath`, so repeated paths cost no extra system calls."""
    resolved = os.path.realpath(filepath)
    return resolved
@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
    """Return the maximal number of threads for each detected library.

    Return a list with all the supported libraries that have been found. Each
    library is represented by a dict with the following information:

      - "user_api" : user API. Possible values are {USER_APIS}.
      - "internal_api": internal API. Possible values are {INTERNAL_APIS}.
      - "prefix" : filename prefix of the specific implementation.
      - "filepath": path to the loaded library.
      - "version": version of the library (if available).
      - "num_threads": the current thread limit.

    In addition, each library may contain internal_api specific entries.
    """
    # A fresh controller is built on every call so that the result reflects the
    # libraries found at call time.
    controller = ThreadpoolController()
    return controller.info()
class _ThreadpoolLimiter:
    """The guts of ThreadpoolController.limit

    Refer to the docstring of ThreadpoolController.limit for more details.

    It will only act on the library controllers held by the provided `controller`.
    Using the default constructor sets the limits right away such that it can be used as
    a callable. Setting the limits can be delayed by using the `wrap` class method such
    that it can be used as a decorator.
    """

    def __init__(self, controller, *, limits=None, user_api=None):
        self._controller = controller
        validated = self._check_params(limits, user_api)
        self._limits, self._user_api, self._prefixes = validated
        # Snapshot taken before any change, used to restore the limits later.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.restore_original_limits()

    @classmethod
    def wrap(cls, controller, *, limits=None, user_api=None):
        """Return an instance of this class that can be used as a decorator"""
        decorator = _ThreadpoolLimiterDecorator(
            controller=controller, limits=limits, user_api=user_api
        )
        return decorator

    def restore_original_limits(self):
        """Set the limits back to their original values"""
        pairs = zip(self._controller.lib_controllers, self._original_info)
        for lib, snapshot in pairs:
            lib.set_num_threads(snapshot["num_threads"])

    # Alias of `restore_original_limits` for backward compatibility
    unregister = restore_original_limits

    def get_original_num_threads(self):
        """Original num_threads from before calling threadpool_limits

        Return a dict `{user_api: num_threads}`.
        """
        original = {}
        ambiguous_apis = []

        for api in self._user_api:
            distinct = {
                snapshot["num_threads"]
                for snapshot in self._original_info
                if snapshot["user_api"] == api
            }
            if not distinct:
                value = None
            elif len(distinct) == 1:
                (value,) = distinct
            else:
                # Several libraries of the same API disagree: report the lowest.
                value = min(distinct)
                ambiguous_apis.append(api)
            original[api] = value

        if ambiguous_apis:
            warnings.warn(
                "Multiple value possible for following user apis: "
                + ", ".join(ambiguous_apis)
                + ". Returning the minimum."
            )

        return original

    def _check_params(self, limits, user_api):
        """Suitable values for the _limits, _user_api and _prefixes attributes"""
        if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
            params = self._controller._get_params_for_sequential_blas_under_openmp()
            limits, user_api = params.values()

        if limits is None or isinstance(limits, int):
            # A single (or no) limit, applied to the APIs selected by `user_api`.
            if user_api is None:
                user_api = _ALL_USER_APIS
            elif user_api in _ALL_USER_APIS:
                user_api = [user_api]
            else:
                raise ValueError(
                    f"user_api must be either in {_ALL_USER_APIS} or None. Got "
                    f"{user_api} instead."
                )

            if limits is not None:
                limits = dict.fromkeys(user_api, limits)
            return limits, user_api, []

        if isinstance(limits, list):
            # This should be a list of dicts of library info, for
            # compatibility with the result from threadpool_info.
            limits = {info["prefix"]: info["num_threads"] for info in limits}
        elif isinstance(limits, ThreadpoolController):
            # To set the limits from the library controllers of a
            # ThreadpoolController object.
            limits = {lib.prefix: lib.num_threads for lib in limits.lib_controllers}

        if not isinstance(limits, dict):
            raise TypeError(
                "limits must either be an int, a list, a dict, or "
                f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
            )

        # With a dictionary, can set both specific limit for given
        # libraries and global limit for user_api. Fetch each separately.
        prefixes = [key for key in limits if key in _ALL_PREFIXES]
        user_api = [key for key in limits if key in _ALL_USER_APIS]
        return limits, user_api, prefixes

    def _set_threadpool_limits(self):
        """Change the maximal number of threads in selected thread pools.

        Does nothing when no limit was requested. Libraries matching neither a
        prefix nor a user_api of `self._limits` are left untouched.
        """
        limits = self._limits
        if limits is None:
            return

        for lib in self._controller.lib_controllers:
            # limits is a dict {key: num_threads} where key is either
            # a prefix or a user_api. If a library matches both, the limit
            # corresponding to the prefix is chosen.
            if lib.prefix in limits:
                key = lib.prefix
            elif lib.user_api in limits:
                key = lib.user_api
            else:
                continue

            num_threads = limits[key]
            if num_threads is not None:
                lib.set_num_threads(num_threads)
class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
    """Same as _ThreadpoolLimiter but to be used as a decorator"""

    def __init__(self, controller, *, limits=None, user_api=None):
        # The controller must be stored before validating the parameters:
        # `_check_params` reads `self._controller` when `limits` is
        # 'sequential_blas_under_openmp'.
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )

    def __enter__(self):
        # we need to set the limits here and not in the __init__ because we want the
        # limits to be set when calling the decorated function, not when creating the
        # decorator.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()
        return self
@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
    """Change the maximal number of threads that can be used in thread pools.

    This object can be used either as a callable (the construction of this object
    limits the number of threads), as a context manager in a `with` block to
    automatically restore the original state of the controlled libraries when exiting
    the block, or as a decorator through its `wrap` method.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread level
    isolation as these libraries do not offer thread-local APIs to configure the number
    of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """

    def __init__(self, limits=None, user_api=None):
        # Each instance acts on a freshly built controller.
        controller = ThreadpoolController()
        super().__init__(controller, limits=limits, user_api=user_api)

    @classmethod
    def wrap(cls, limits=None, user_api=None):
        """Return a decorator applying the limits while the decorated function runs."""
        controller = ThreadpoolController()
        return super().wrap(controller, limits=limits, user_api=user_api)
class ThreadpoolController:
"""Collection of LibController objects for all loaded supported libraries
Attributes
----------
lib_controllers : list of `LibController` objects
The list of library controllers of all loaded supported libraries.
"""
# Cache for libc under POSIX and a few system libraries under Windows.
# We use a class level cache instead of an instance level cache because
# it's very unlikely that a shared library will be unloaded and reloaded
# during the lifetime of a program.
_system_libraries = dict()
def __init__(self):
    """Collect a controller for every supported library found by `_load_libraries`."""
    # Start from an empty list; the loading step below fills it in.
    self.lib_controllers = []
    self._load_libraries()
    # Run once the list is filled.
    self._warn_if_incompatible_openmp()
@classmethod
def _from_controllers(cls, lib_controllers):
new_controller = cls.__new__(cls)
new_controller.lib_controllers = lib_controllers
return new_controller
def info(self):
"""Return lib_controllers info as a list of dicts"""
return [lib_controller.info() for lib_controller in self.lib_controllers]
def select(self, **kwargs):
"""Return a ThreadpoolController containing a subset of its current
library controllers
It will select all libraries matching at least one pair (key, value) from kwargs
where key is an entry of the library info dict (like "user_api", "internal_api",
"prefix", ...) and value is the value or a list of acceptable values for that
entry.
For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])`
will select all library controllers whose internal_api is either "blis" or
"openblas".
"""
for key, vals in kwargs.items():
kwargs[key] = [vals] if not isinstance(vals, list) else vals
lib_controllers = [
lib_controller
for lib_controller in self.lib_controllers
if any(
getattr(lib_controller, key, None) in vals
for key, vals in kwargs.items()
)
]
return ThreadpoolController._from_controllers(lib_controllers)
def _get_params_for_sequential_blas_under_openmp(self):
"""Return appropriate params to use for a sequential BLAS call in an OpenMP loop
This function takes into account the unexpected behavior of OpenBLAS with the
OpenMP threading layer.
"""
if self.select(
internal_api="openblas", threading_layer="openmp"
).lib_controllers:
return {"limits": None, "user_api": None}
return {"limits": 1, "user_api": "blas"}
@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def limit(self, *, limits=None, user_api=None):
    """Change the maximal number of threads that can be used in thread pools.

    This function returns an object that can be used either as a callable (the
    construction of this object limits the number of threads) or as a context
    manager, in a `with` block to automatically restore the original state of the
    controlled libraries when exiting the block.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread
    level isolation as these libraries do not offer thread-local APIs to configure
    the number of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """
    # The limiter applies the limits as soon as it is constructed.
    limiter = _ThreadpoolLimiter(self, limits=limits, user_api=user_api)
    return limiter
@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def wrap(self, *, limits=None, user_api=None):
    """Change the maximal number of threads that can be used in thread pools.

    This function returns an object that can be used as a decorator.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    Parameters
    ----------
    limits : int, dict or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """
    # Unlike `limit`, the limits are only applied when the decorated function runs.
    decorator = _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)
    return decorator
def __len__(self):
return len(self.lib_controllers)
def _load_libraries(self):
    """Loop through loaded shared libraries and store the supported ones"""
    platform = sys.platform
    if platform == "darwin":
        finder = self._find_libraries_with_dyld
    elif platform == "win32":
        finder = self._find_libraries_with_enum_process_module_ex
    else:
        # Every other platform goes through dl_iterate_phdr.
        finder = self._find_libraries_with_dl_iterate_phdr
    finder()
def _find_libraries_with_dl_iterate_phdr(self):
"""Loop through loaded libraries and return binders on supported ones
This function is expected to work on POSIX system only.
This code is adapted from code by Intel developer @anton-malakhov
available at https://github.com/IntelPython/smp
Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause
license
"""
libc = self._get_libc()
if not hasattr(libc, "dl_iterate_phdr"): # pragma: no cover
return []
# Callback function for `dl_iterate_phdr` which is called for every
# library loaded in the current process until it returns 1.
def match_library_callback(info, size, data):
# Get the path of the current library
filepath = info.contents.dlpi_name
if filepath:
filepath = filepath.decode("utf-8")
# Store the library controller if it is supported and selected
self._make_controller_from_path(filepath)
return 0
c_func_signature = ctypes.CFUNCTYPE(
ctypes.c_int, # Return type
ctypes.POINTER(_dl_phdr_info),
ctypes.c_size_t,
ctypes.c_char_p,
)
c_match_library_callback = c_func_signature(match_library_callback)
data = ctypes.c_char_p(b"")
libc.dl_iterate_phdr(c_match_library_callback, data)
def _find_libraries_with_dyld(self):
"""Loop through loaded libraries and return binders on supported ones
This function is expected to work on OSX system only
"""
libc = self._get_libc()
if not hasattr(libc, "_dyld_image_count"): # pragma: no cover
return []
n_dyld = libc._dyld_image_count()
libc._dyld_get_image_name.restype = ctypes.c_char_p
for i in range(n_dyld):
filepath = ctypes.string_at(libc._dyld_get_image_name(i))
filepath = filepath.decode("utf-8")
# Store the library controller if it is supported and selected
self._make_controller_from_path(filepath)
def _find_libraries_with_enum_process_module_ex(self):
    """Loop through loaded libraries and store the supported ones

    This function is expected to work on windows system only.

    The current process is opened, its module handles are listed with
    `EnumProcessModulesEx`, and the file path of each module is passed to
    `_make_controller_from_path`.

    This code is adapted from code by Philipp Hagemeister @phihag available
    at https://stackoverflow.com/questions/17474574

    Raises
    ------
    OSError
        If the process cannot be opened, or if `EnumProcessModulesEx` or
        `GetModuleFileNameExW` report a failure.
    """
    from ctypes.wintypes import DWORD, HMODULE, MAX_PATH

    PROCESS_QUERY_INFORMATION = 0x0400
    PROCESS_VM_READ = 0x0010

    LIST_LIBRARIES_ALL = 0x03

    ps_api = self._get_windll("Psapi")
    kernel_32 = self._get_windll("kernel32")

    h_process = kernel_32.OpenProcess(
        PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid()
    )
    if not h_process:  # pragma: no cover
        raise OSError(f"Could not open PID {os.getpid()}")

    try:
        handle_size = ctypes.sizeof(HMODULE)
        buf_count = 256
        needed = DWORD()
        # Grow the buffer until it becomes large enough to hold all the
        # module headers
        while True:
            buf = (HMODULE * buf_count)()
            buf_size = ctypes.sizeof(buf)
            if not ps_api.EnumProcessModulesEx(
                h_process,
                ctypes.byref(buf),
                buf_size,
                ctypes.byref(needed),
                LIST_LIBRARIES_ALL,
            ):
                raise OSError("EnumProcessModulesEx failed")
            if buf_size >= needed.value:
                break
            # `needed` is expressed in bytes: convert it to a handle count.
            buf_count = needed.value // handle_size

        count = needed.value // handle_size
        h_modules = map(HMODULE, buf[:count])

        # Loop through all the module headers and get the library path
        path_buf = ctypes.create_unicode_buffer(MAX_PATH)
        for h_module in h_modules:
            # Get the path of the current module. The last argument
            # (`nSize`) is the capacity of `path_buf` in characters and is
            # passed by value; it is not an output pointer.
            if not ps_api.GetModuleFileNameExW(
                h_process, h_module, ctypes.byref(path_buf), MAX_PATH
            ):
                raise OSError("GetModuleFileNameEx failed")
            filepath = path_buf.value

            # Store the library controller if it is supported and selected
            self._make_controller_from_path(filepath)
    finally:
        # Always release the process handle, even when enumeration fails.
        kernel_32.CloseHandle(h_process)
def _make_controller_from_path(self, filepath):
    """Store a library controller if it is supported and selected

    `filepath` is resolved with `_realpath`, its lowercased basename is
    compared with the filename prefixes of every class in
    `_ALL_CONTROLLERS`, and a controller is appended to
    `self.lib_controllers` for each class that accepts it.
    """
    # Symlinks have to be resolved before looking at the file name.
    filepath = _realpath(filepath)
    # Lowercasing copes with the various spellings of the OpenMP dll on
    # Windows (vcomp, VCOMP, Vcomp, ...)
    basename = os.path.basename(filepath).lower()

    for candidate in _ALL_CONTROLLERS:
        matched_prefix = self._check_prefix(basename, candidate.filename_prefixes)
        if matched_prefix is None:
            # None of the prefixes of this candidate library match the
            # file name: try the next one.
            continue

        if matched_prefix == "libblas":
            # We ignore libblas on other platforms than windows because
            # there might be a libblas dso coming with openblas for instance
            # that can't be used to instantiate a pertinent LibController
            # (many symbols are missing) and would create confusion by making
            # a duplicate entry in threadpool_info.
            if not basename.endswith(".dll"):
                continue

            # Workaround for BLAS libraries packaged by conda-forge on
            # windows, which are all renamed "libblas.dll". The actual BLAS
            # implementation is identified by looking for implementation
            # specific symbols.
            libblas = ctypes.CDLL(filepath, _RTLD_NOLOAD)
            exposes_known_symbol = False
            for symbol in candidate.check_symbols:
                if hasattr(libblas, symbol):
                    exposes_known_symbol = True
                    break
            if not exposes_known_symbol:
                continue

        # The file name matches a prefix: create and store the controller.
        self.lib_controllers.append(
            candidate(filepath=filepath, prefix=matched_prefix)
        )
def _check_prefix(self, library_basename, filename_prefixes):
"""Return the prefix library_basename starts with
Return None if none matches.
"""
for prefix in filename_prefixes:
if library_basename.startswith(prefix):
return prefix
return None
def _warn_if_incompatible_openmp(self):
"""Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded"""
prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers]
msg = textwrap.dedent("""
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
""")
if "libomp" in prefixes and "libiomp" in prefixes:
warnings.warn(msg, RuntimeWarning)
@classmethod
def _get_libc(cls):
"""Load the lib-C for unix systems."""
libc = cls._system_libraries.get("libc")
if libc is None:
libc_name = find_library("c")
if libc_name is None: # pragma: no cover
warnings.warn(
"libc not found. The ctypes module in Python"
f" {sys.version_info.major}.{sys.version_info.minor} is maybe"
" too old for this OS.",
RuntimeWarning,
)
return None
libc = ctypes.CDLL(libc_name, mode=_RTLD_NOLOAD)
cls._system_libraries["libc"] = libc
return libc
@classmethod
def _get_windll(cls, dll_name):
"""Load a windows DLL"""
dll = cls._system_libraries.get(dll_name)
if dll is None:
dll = ctypes.WinDLL(f"{dll_name}.dll")
cls._system_libraries[dll_name] = dll
return dll
def _main():
    """Commandline interface to display thread-pool information and exit.

    The modules given with `-i/--import` are imported and the statement
    given with `-c/--command` is executed, then the result of
    `threadpool_info()` is printed as indented JSON.
    """
    import argparse
    import importlib
    import json
    import sys

    parser = argparse.ArgumentParser(
        description="Display thread-pool information and exit.",
        usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
    )
    parser.add_argument(
        "-i",
        "--import",
        help="Python modules to import before introspecting thread-pools.",
        dest="modules",
        nargs="*",
        default=(),
    )
    parser.add_argument(
        "-c",
        "--command",
        help="a Python statement to execute before introspecting thread-pools.",
    )
    options = parser.parse_args(sys.argv[1:])

    for module in options.modules:
        try:
            importlib.import_module(module, package=None)
        except ImportError:
            # A missing module is reported on stderr and otherwise ignored.
            print("WARNING: could not import", module, file=sys.stderr)

    # NOTE: the statement comes straight from the command line and runs with
    # this function's globals and locals in scope, so the local names above
    # are visible to it.
    if options.command:
        exec(options.command)

    report = json.dumps(threadpool_info(), indent=2)
    print(report)
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    _main()