Source code for ipie.utils.misc


# Copyright 2022 The ipie Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors: Fionn Malone <fionn.malone@gmail.com>
#          Joonho Lee
#

"""Various useful routines maybe not appropriate elsewhere"""

import os
import socket
import subprocess
import sys
import time
import types
from functools import reduce

import numpy
import scipy.sparse


def is_cupy(obj):
    """Report whether ``obj``'s type string mentions ``cupy``.

    Detection is by type name only, so CuPy never needs to be importable.
    """
    return "cupy" in str(type(obj))
def to_numpy(obj):
    """Move ``obj`` to host memory if it is a CuPy object; otherwise return it unchanged.

    CuPy is imported lazily, only when a CuPy object is actually seen.
    """
    if "cupy" not in str(type(obj)):
        return obj
    import cupy

    return cupy.asnumpy(obj)
def get_git_info():
    """Return git info.

    Adapted from:
    http://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script

    Git is run in the directory two levels above this file.

    Returns
    -------
    sha1 : string or None
        git hash, with -dirty appended if tracked files under ./ipie have
        uncommitted changes. "none" if git could not be run at all, None if
        git ran but failed (e.g. not a git checkout).
    branch : string or None
        Current branch, or None if it could not be determined.
    local_mods : list of strings
        Whitespace-split output of ``git status --porcelain ./ipie`` (status
        codes and paths interleaved), covering tracked and untracked files.
    """
    src = os.path.dirname(__file__) + "/../../"
    try:
        sha1 = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], cwd=src, stderr=subprocess.DEVNULL
        ).strip()
        # Non-empty output means tracked files are modified (-uno hides
        # untracked files).
        suffix = subprocess.check_output(
            ["git", "status", "-uno", "--porcelain", "./ipie"], cwd=src
        ).strip()
        local_mods = (
            subprocess.check_output(["git", "status", "--porcelain", "./ipie"], cwd=src)
            .strip()
            .decode("utf-8")
            .split()
        )
        branch = subprocess.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=src
        ).strip()
    except subprocess.CalledProcessError:
        # git ran but returned an error, e.g. the tree is not a repository.
        return None, None, []
    except Exception as error:
        # e.g. git is not installed or src does not exist. Return a placeholder
        # instead of retrying git, which would fail the same way.
        print(f"couldn't determine git hash : {error}")
        return "none", None, []
    sha1 = sha1.decode("utf-8")
    if suffix:
        sha1 += "-dirty"
    return sha1, branch.decode("utf-8"), local_mods
def is_h5file(obj):
    """Report whether ``obj``'s type string mentions ``h5py``.

    Matching on the type name avoids importing h5py.
    """
    return "h5py" in str(type(obj))
def is_class(obj):
    """Decide whether ``obj`` should be treated as a generic attribute-bearing object.

    True when ``obj`` lists ``__dict__`` in ``dir()``, is not a plain Python
    function and is not an h5py object. The checks run in that order and
    stop at the first failure.
    """
    if not hasattr(obj, "__class__"):
        return False
    if "__dict__" not in dir(obj):
        return False
    if isinstance(obj, types.FunctionType):
        return False
    return not is_h5file(obj)
def serialise(obj, verbose=0):
    """Convert an object's attributes (or a dict's items) into a plain nested dict.

    Parameters
    ----------
    obj : dict or object
        A dict's items are used directly; otherwise ``obj.__dict__`` is used.
    verbose : int
        1 -- additionally stringify functions, bound methods and ``store``
        entries, and record h5py file names.
        2 -- store every 1-D ndarray.
        3 -- store every ndarray regardless of rank.
        any other value -- store 1-D ndarrays only if their norm exceeds 1e-8.

    Returns
    -------
    obj_dict : dict
        Complex arrays are split into [real, imag] lists; real arrays are
        stored as a 1-tuple holding ``tolist()``. Sparse CSR/CSC matrices,
        ``estimates``/``global_estimates`` values and unrecognised types
        are skipped.
    """
    obj_dict = {}
    if isinstance(obj, dict):
        items = obj.items()
    else:
        items = obj.__dict__.items()
    for k, v in items:
        if isinstance(v, (scipy.sparse.csr_matrix, scipy.sparse.csc_matrix)):
            # Sparse matrices are never serialised.
            pass
        elif is_class(v):
            # Generic object: recurse into its attributes.
            # NOTE(review): this runs before the "estimates"/"walkers" key
            # checks below, so object-valued entries with those keys are
            # recursed into, not skipped -- confirm this is intended.
            obj_dict[k] = serialise(v, verbose)
        elif isinstance(v, dict):
            # Propagate verbose so nested dicts obey the same rules.
            obj_dict[k] = serialise(v, verbose)
        elif isinstance(v, types.FunctionType):
            if verbose == 1:
                obj_dict[k] = str(v)
        elif hasattr(v, "__self__"):
            # Bound method (has a __self__ it is bound to).
            if verbose == 1:
                obj_dict[k] = str(v)
        elif k == "estimates" or k == "global_estimates":
            pass
        elif k == "walkers":
            # Record only the first walker's string form; skip if empty.
            for walker in v:
                obj_dict[k] = str(walker)
                break
        elif isinstance(v, numpy.ndarray):
            # iscomplexobj covers complex64/complex128/clongdouble alike.
            is_cplx = numpy.iscomplexobj(v)
            # size guard: v[0] would raise IndexError on an empty array.
            has_first = v.size > 0 and v[0] is not None
            if verbose == 3:
                if is_cplx:
                    obj_dict[k] = [v.real.tolist(), v.imag.tolist()]
                else:
                    obj_dict[k] = (v.tolist(),)
            elif verbose == 2:
                if v.ndim == 1:
                    if has_first and is_cplx:
                        obj_dict[k] = [[v.real.tolist(), v.imag.tolist()]]
                    else:
                        obj_dict[k] = (v.tolist(),)
            elif v.ndim == 1:
                if has_first and numpy.linalg.norm(v) > 1e-8:
                    if is_cplx:
                        obj_dict[k] = [[v.real.tolist(), v.imag.tolist()]]
                    else:
                        obj_dict[k] = (v.tolist(),)
        elif k == "store":
            if verbose == 1:
                obj_dict[k] = str(v)
        elif isinstance(v, (int, float, bool, str)):
            obj_dict[k] = v
        elif isinstance(v, complex):
            # Only the real part of complex scalars is kept.
            obj_dict[k] = v.real
        elif v is None:
            obj_dict[k] = v
        elif is_h5file(v):
            if verbose == 1:
                obj_dict[k] = v.filename
        else:
            pass
    return obj_dict
class dotdict(dict):
    """A dict whose items can also be read, set and deleted as attributes.

    Reading a missing attribute returns None (it goes through ``dict.get``)
    instead of raising AttributeError.
    """

    def __getattr__(self, name):
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
def update_stack(stack_size, num_slices, name="stack", verbose=False):
    """Snap ``stack_size`` to the nearest divisor of ``num_slices``.

    The search starts from ``min(stack_size, num_slices)``. It walks down and
    up to the closest sizes that divide ``num_slices`` exactly, and ties go to
    the smaller one.
    """

    def leaves_remainder(size):
        # True while num_slices is not an exact multiple of size.
        return (num_slices // size) * size < num_slices

    start = min(stack_size, num_slices)
    below = start
    while leaves_remainder(below):
        below -= 1
    above = start
    while leaves_remainder(above):
        above += 1

    adjusted = below if (stack_size - below) <= (above - stack_size) else above

    if verbose:
        print("# Initial {} upper_bound is {}".format(name, above))
        print("# Initial {} lower_bound is {}".format(name, below))
        print("# Adjusted {} size is {}".format(name, adjusted))
    return adjusted
def merge_dicts(a, b, path=None):
    """Recursively merge dict ``b`` into dict ``a`` in place and return ``a``.

    Parameters
    ----------
    a : dict
        Destination; modified in place.
    b : dict
        Source of new keys.
    path : list of str, optional
        Key path of the current recursion level, used in error messages.

    Returns
    -------
    a : dict
        The merged destination dict.

    Raises
    ------
    ValueError
        If a key is present in both with different non-dict values. Being an
        Exception subclass, it is still caught by ``except Exception``.
    """
    if path is None:
        path = []
    for key in b:
        if key not in a:
            a[key] = b[key]
        elif isinstance(a[key], dict) and isinstance(b[key], dict):
            merge_dicts(a[key], b[key], path + [str(key)])
        elif a[key] == b[key]:
            pass  # same leaf value
        else:
            raise ValueError("Conflict at %s" % ".".join(path + [str(key)]))
    return a
def get_from_dict(d, k):
    """Walk a nested dictionary along a sequence of keys.

    Inspired by:
    https://stackoverflow.com/questions/28225552/is-there-a-recursive-version-of-the-dict-get-built-in

    Parameters
    ----------
    d : dict
        Root dictionary.
    k : list
        Keys to follow, outermost first.

    Returns
    -------
    value : Value at the end of the path, or None if any step is missing or
        lands on something that is not a dict.
    """
    node = d
    try:
        for key in k:
            node = dict.get(node, key)
    except TypeError:
        # dict.get on a non-dict (e.g. None from a missing key) lands here.
        return None
    return node
def get_numeric_names(d):
    """Collect keys of numeric entries in ``d`` and their total element count.

    ndarrays count ``.size`` elements and int/float/complex scalars count one.
    Lists are always named, and they contribute the counts of their array and
    scalar members. Other values are ignored.

    Returns
    -------
    names : list
        Keys whose values are arrays, numeric scalars or lists.
    size : int
        Total number of numeric elements.
    """

    def element_count(item):
        if isinstance(item, numpy.ndarray):
            return item.size
        if isinstance(item, (int, float, complex)):
            return 1
        return 0

    names = []
    size = 0
    for key, value in d.items():
        if isinstance(value, (numpy.ndarray, int, float, complex)):
            names.append(key)
            size += element_count(value)
        elif isinstance(value, list):
            names.append(key)
            size += sum(element_count(item) for item in value)
    return names, size
def get_node_mem():
    """Return total physical memory of this node in GiB, or 0.0 if unavailable.

    Computed as SC_PHYS_PAGES * SC_PAGE_SIZE bytes from ``os.sysconf``.
    """
    try:
        return os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / 1024**3.0
    except (AttributeError, ValueError, OSError):
        # AttributeError: no os.sysconf on this platform (e.g. Windows);
        # ValueError: name unknown here; OSError: the query itself failed.
        # Unlike a bare except, KeyboardInterrupt/SystemExit still propagate.
        return 0.0
def timeit(func):
    """Decorator that prints the elapsed time of each call to ``func``.

    The wrapped function's return value is passed through unchanged, and its
    name/docstring are preserved via ``functools.wraps``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the interval is
        # unaffected by system clock adjustments (unlike time.time).
        start = time.perf_counter()
        res = func(*args, **kwargs)
        end = time.perf_counter()
        print(" # Time : {} ".format(end - start))
        return res

    return wrapper