--- a +++ b/scvae/analyses/images.py @@ -0,0 +1,93 @@ +# ======================================================================== # +# +# Copyright (c) 2017 - 2020 scVAE authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ======================================================================== # + +import os + +import numpy +import PIL +import scipy.sparse + +from scvae.analyses.figures import saving + +IMAGE_EXTENSION = ".png" +DEFAULT_NUMBER_OF_RANDOM_EXAMPLES_FOR_COMBINED_IMAGES = 100 + + +def combine_images_from_data_set(data_set, indices=None, + number_of_random_examples=None, name=None): + + image_name = saving.build_figure_name("random_image_examples", name) + random_state = numpy.random.RandomState(13) + + if indices is not None: + n_examples = len(indices) + if number_of_random_examples is not None: + n_examples = min(n_examples, number_of_random_examples) + indices = random_state.permutation(indices)[:n_examples] + else: + if number_of_random_examples is not None: + n_examples = number_of_random_examples + else: + n_examples = DEFAULT_NUMBER_OF_RANDOM_EXAMPLES_FOR_COMBINED_IMAGES + indices = random_state.permutation( + data_set.number_of_examples)[:n_examples] + + if n_examples == 1: + image_name = saving.build_figure_name("image_example", name) + else: + image_name = saving.build_figure_name("image_examples", name) + + width, height = data_set.feature_dimensions + + examples = data_set.values[indices] + if scipy.sparse.issparse(examples): + examples = examples.A + examples = examples.reshape(n_examples, width, height) + + column = int(numpy.ceil(numpy.sqrt(n_examples))) + row = int(numpy.ceil(n_examples / column)) + + image = numpy.zeros((row * width, column * height)) + + for m in range(n_examples): + c = int(m % column) + r = int(numpy.floor(m / column)) + rows = slice(r*width, (r+1)*width) + columns = slice(c*height, (c+1)*height) + image[rows, columns] = examples[m] + + return image, image_name + + +def save_image(image, name, directory): + + if not os.path.exists(directory): + os.makedirs(directory) + + minimum = image.min() + maximum = image.max() + if 0 < minimum and minimum < 1 and 0 < maximum and maximum < 1: + rescaled_image = 255 * image + else: + rescaled_image = (255 / (maximum - minimum) * (image - minimum)) + + image = PIL.Image.fromarray(rescaled_image.astype(numpy.uint8)) + + name += IMAGE_EXTENSION + image_path = os.path.join(directory, name) + image.save(image_path)