Source code for myqueue.caching

"""Simple caching function implementation using JSON."""
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Sequence, TypeVar
from functools import lru_cache, wraps


# Type variable that ties the return type of the function handed to
# json_cached_function to the return type of the caching wrapper it builds.
T = TypeVar('T')


class CacheFileNotFoundError(FileNotFoundError):
    """Signal that an expected JSON cache file is missing.

    Derives from :class:`FileNotFoundError`, so handlers written for that
    (or for :class:`OSError`) also catch it.
    """


def json_cached_function(function: Callable[..., T],
                         name: str,
                         args: Sequence[Any],
                         kwargs: dict[str, Any]) -> Callable[..., T]:
    """Add file-caching to function.

    The decorated function will write its result in JSON format to a file
    called <name>.result.

    Parameters
    ----------
    function:
        The function whose result should be cached.
    name:
        Base name of the cache file; the file used is ``<name>.result``
        (a relative name is resolved against the working directory at the
        time the returned callable is invoked).
    args, kwargs:
        Positional and keyword arguments that *function* is called with.

    Returns
    -------
    A callable taking only the optional *only_read_from_cache* flag.  If the
    cache file exists its decoded contents are returned.  Otherwise
    *function* is called, its result is written to the cache file on MPI
    rank 0 only, and the result is returned.
    """
    path = Path(f'{name}.result')

    # NOTE: functools.wraps copies function.__doc__ onto new_func, so the
    # docstring below is replaced by the wrapped function's docstring.
    @wraps(function)
    def new_func(only_read_from_cache: bool = False) -> T:
        """A caching function.

        If *only_read_from_cache* is True then a CacheFileNotFoundError
        exception will be raised if the file does not exist.
        """
        if path.is_file():
            return decode(path.read_text(encoding='utf-8'))
        if only_read_from_cache:
            import errno
            import os
            # Use the (errno, strerror, filename) form so the exception
            # carries the missing path (``err.filename``) and a readable
            # message instead of an empty string.
            raise CacheFileNotFoundError(errno.ENOENT,
                                         os.strerror(errno.ENOENT),
                                         str(path))
        result = function(*args, **kwargs)
        # Only rank 0 writes the file (presumably so that several MPI ranks
        # do not write the same file concurrently).
        if mpi_world().rank == 0:
            path.write_text(encode(result), encoding='utf-8')
        return result

    return new_func
class MPIWorld:
    """A no-MPI implementation."""

    rank: int = 0


@lru_cache()
def mpi_world() -> MPIWorld:
    """Find and return a world object with a rank attribute.

    Only modules that are already imported are inspected; no MPI library is
    imported here.  Candidates are tried in this order:

    1. ``mpi4py.MPI`` -> ``COMM_WORLD``
    2. ``_gpaw`` -> ``Communicator()`` (if it has that attribute)
    3. ``_asap`` -> ``Communicator()`` (if it has that attribute)

    If none is found, a :class:`MPIWorld` (rank 0) is returned.  The result
    is cached by ``lru_cache``, so modules imported after the first call are
    not picked up.
    """
    import sys

    # Look up the submodule itself rather than ``mpi4py`` followed by
    # ``.MPI``: a bare ``import mpi4py`` does not import ``mpi4py.MPI``, in
    # which case the attribute access would raise AttributeError.
    mpi = sys.modules.get('mpi4py.MPI')
    if mpi is not None:
        return mpi.COMM_WORLD  # type: ignore

    for module_name in ['_gpaw', '_asap']:
        mod = sys.modules.get(module_name)
        if hasattr(mod, 'Communicator'):
            return mod.Communicator()  # type: ignore

    return MPIWorld()
class Encoder(json.JSONEncoder):
    """Encode complex, datetime, Path and ndarray objects.

    >>> import numpy as np
    >>> Encoder().encode(1+2j)
    '{"__complex__": [1.0, 2.0]}'
    >>> Encoder().encode(datetime(1969, 11, 11, 0, 0))
    '{"__datetime__": "1969-11-11T00:00:00"}'
    >>> Encoder().encode(Path('abc/123.xyz'))
    '{"__path__": "abc/123.xyz"}'
    >>> Encoder().encode(np.array([1., 2.]))
    '{"__ndarray__": [1.0, 2.0]}'
    >>> Encoder().encode(np.array(1+2j))
    '{"__ndarray__": [1.0, 2.0], "dtype": "complex", "shape": []}'
    """

    def default(self, obj: Any) -> Any:
        """Return a JSON-serializable, tagged stand-in for *obj*.

        Objects with an ``__array__`` attribute are treated as arrays.
        ``dtype``/``shape`` keys are added only when they cannot be inferred
        from the nested list.  Unsupported objects are passed to
        :meth:`json.JSONEncoder.default`, which raises TypeError.
        """
        if isinstance(obj, complex):
            return {'__complex__': [obj.real, obj.imag]}
        if isinstance(obj, datetime):
            return {'__datetime__': obj.isoformat()}
        if isinstance(obj, Path):
            return {'__path__': str(obj)}
        if hasattr(obj, '__array__'):
            dct: dict[str, Any]
            if obj.dtype == complex:
                import numpy as np
                # Complex data is stored as interleaved (real, imag) floats.
                # view(float) needs at least one dimension and a contiguous
                # last axis, so take a C-ordered, >= 1-d copy first; calling
                # it directly fails for e.g. transposed or 0-d arrays.
                floats = np.array(obj, order='C', ndmin=1).view(float)
                dct = {'__ndarray__': floats.tolist(), 'dtype': 'complex'}
                if obj.shape == ():
                    # ndmin=1 added an axis; record the true (0-d) shape.
                    dct['shape'] = obj.shape
            else:
                dct = {'__ndarray__': obj.tolist()}
                if obj.dtype not in [int, float]:
                    dct['dtype'] = obj.dtype.name
            if obj.size == 0:
                # An empty nested list cannot carry the shape by itself.
                dct['shape'] = obj.shape
            return dct
        return json.JSONEncoder.default(self, obj)
# Module-level shortcut: ``encode(obj)`` serializes *obj* to a JSON string
# using a single shared Encoder instance.
encode = Encoder().encode
def object_hook(dct: dict[str, Any]) -> Any:
    """Rebuild Python objects from the tagged dictionaries made by Encoder.

    Intended as the ``object_hook`` argument of :func:`json.loads`.  Tags are
    tried in the order ``__complex__``, ``__datetime__``, ``__path__``,
    ``__ndarray__``; a tag whose value is ``None`` is treated as absent.  A
    dictionary without a recognized tag is returned unchanged.

    >>> object_hook({'__complex__': [1.0, 2.0]})
    (1+2j)
    >>> object_hook({'__datetime__': '1969-11-11T00:00:00'})
    datetime.datetime(1969, 11, 11, 0, 0)
    >>> object_hook({'__path__': 'abc/123.xyz'})
    PosixPath('abc/123.xyz')
    >>> object_hook({'__ndarray__': [1.0, 2.0]})
    array([1., 2.])
    """
    converters = (('__complex__', lambda parts: complex(*parts)),
                  ('__datetime__', datetime.fromisoformat),
                  ('__path__', Path))
    for tag, convert in converters:
        payload = dct.get(tag)
        if payload is not None:
            return convert(payload)

    payload = dct.get('__ndarray__')
    if payload is None:
        return dct

    import numpy as np
    if dct.get('dtype') == 'complex':
        # Complex data is stored as interleaved (real, imag) floats.
        array = np.array(payload, dtype=float).view(complex)
    else:
        array = np.array(payload, dtype=dct.get('dtype'))

    shape = dct.get('shape')
    if shape is not None:
        array.shape = shape
    return array
def decode(text: str) -> Any:
    """Parse the JSON document *text* into Python object(s).

    Each JSON object is routed through ``object_hook``, so the tagged
    representations produced by ``Encoder`` come back as Python objects.
    """
    hook = object_hook
    return json.loads(text, object_hook=hook)