diff --git a/dpdata/lmdb/format.py b/dpdata/lmdb/format.py index 9b518be6..2e7ae3c4 100644 --- a/dpdata/lmdb/format.py +++ b/dpdata/lmdb/format.py @@ -2,15 +2,10 @@ import os -import lmdb -import msgpack -import msgpack_numpy as m import numpy as np from dpdata.format import Format -m.patch() - class LMDBError(Exception): """Base class for LMDB errors.""" @@ -63,6 +58,11 @@ class LMDBFormat(Format): >>> loaded_multi_systems = dpdata.MultiSystems.from_file("my_multi_system_db.lmdb", fmt="lmdb") """ + def __init__(self, *args, **kwargs) -> None: + import msgpack_numpy as m + + m.patch() + def to_multi_systems( self, formulas, directory, map_size=1000000000, frame_idx_fmt="012d", **kwargs ): @@ -86,6 +86,9 @@ def to_multi_systems( tuple (self, formula) to be used by to_system """ + import lmdb + import msgpack + self._frame_idx_fmt = frame_idx_fmt self._global_frame_idx = 0 self._system_info = [] @@ -105,6 +108,8 @@ def to_multi_systems( self._txn = None def _dump_to_txn(self, data, txn, formula, dtypes): + import msgpack + from dpdata.data_type import Axis nframes = data["coords"].shape[0] @@ -209,6 +214,9 @@ def from_multi_systems(self, file_name, map_size=1000000000, **kwargs): dict data dictionary for each system """ + import lmdb + import msgpack + from dpdata.data_type import Axis, DataType from dpdata.system import LabeledSystem, System