--- a +++ b/unimol/data/plasma_utils.py @@ -0,0 +1,188 @@ +import hashlib +import json +import subprocess +import tempfile +from typing import Hashable + +try: + import pyarrow.plasma as plasma + + PYARROW_AVAILABLE = True +except ImportError: + plasma = None + PYARROW_AVAILABLE = False + + +class PlasmaArray: + """ + Wrapper around numpy arrays that automatically moves the data to shared + memory upon serialization. This is particularly helpful when passing numpy + arrays through multiprocessing, so that data is not unnecessarily + duplicated or pickled. + """ + + def __init__(self, array): + super().__init__() + self.array = array + self.disable = array.nbytes < 134217728 # disable for arrays <128MB + self.object_id = None + self.path = None + + # variables with underscores shouldn't be pickled + self._client = None + self._server = None + self._server_tmp = None + self._plasma = None + + @property + def plasma(self): + if self._plasma is None and not self.disable: + self._plasma = plasma + return self._plasma + + def start_server(self): + if self.plasma is None or self._server is not None: + return + assert self.object_id is None + assert self.path is None + self._server_tmp = tempfile.NamedTemporaryFile() + self.path = self._server_tmp.name + self._server = subprocess.Popen( + ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] + ) + + @property + def client(self): + if self._client is None: + assert self.path is not None + self._client = self.plasma.connect(self.path, num_retries=200) + return self._client + + def __getstate__(self): + """Called on pickle load""" + if self.plasma is None: + return self.__dict__ + if self.object_id is None: + self.start_server() + self.object_id = self.client.put(self.array) + state = self.__dict__.copy() + del state["array"] + state["_client"] = None + state["_server"] = None + state["_server_tmp"] = None + state["_plasma"] = None + return state + + def __setstate__(self, state): + """Called on pickle save""" + self.__dict__.update(state) + if self.plasma is None: + return + self.array = self.client.get(self.object_id) + + def __del__(self): + if self._server is not None: + self._server.kill() + self._server = None + self._server_tmp.close() + self._server_tmp = None + + +DEFAULT_PLASMA_PATH = "/tmp/plasma" + + +class PlasmaView: + """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, + PlasmaView writes to shared memory on instantiation.""" + + def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): + """ + Args: + array: numpy array to store. This can be read with ``PlasmaView().array`` + split_path: the path whence the data was read, used for hashing + hash_data: other metadata about the array that can be used to create a unique key. + as of writing, the 3 callers in ``TokenBlockDataset`` use:: + hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) + """ + assert PYARROW_AVAILABLE + assert split_path is not None + if plasma_path is None: + plasma_path = DEFAULT_PLASMA_PATH + + self.path = plasma_path + self.split_path = split_path + self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. + self._n = None + + self.object_id = self.get_object_id(self.split_path, hash_data) + try: + self.client.put(array, object_id=self.object_id) + except plasma.PlasmaObjectExists: + pass + + @property + def client(self): + if self._client is None: + self._client = plasma.connect(self.path, num_retries=200) + return self._client + + @property + def array(self): + """Fetch a read only view of an np.array, stored in plasma.""" + ret = self.client.get(self.object_id) + return ret + + @staticmethod + def get_object_id(split_path: str, hash_data: Hashable): + """Returns plasma.ObjectID from hashing split_path and object_num.""" + hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) + harg = json.dumps(hash_data).encode("utf-8") + hash.update(harg) + return plasma.ObjectID(hash.digest()) + + def __getstate__(self): + """Called on pickle save""" + self.disconnect() + state = self.__dict__.copy() + assert state["_client"] is None + assert "object_id" in state + return state + + def __setstate__(self, state): + """Called on pickle load""" + self.__dict__.update(state) + + def __del__(self): + self.disconnect() + + def disconnect(self): + if self._client is not None: + self._client.disconnect() + self._client = None + + def __len__(self): + """Save reads by caching len""" + if self._n is None: + self._n = len(self.array) + return self._n + + +GB100 = (1024**3) * 100 + + +class PlasmaStore: + def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): + + self.server = self.start(path, nbytes) + + def __del__(self): + self.server.kill() + + @staticmethod + def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: + if not PYARROW_AVAILABLE: + raise ImportError("please run pip install pyarrow to use --use_plasma_view") + # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm + _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) + plasma.connect(path, num_retries=200) # If we can't connect we fail immediately + return _server \ No newline at end of file