Source code for deepmd.jax.utils.update_sel
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Iterator,
)
from contextlib import (
contextmanager,
)
from typing import (
Any,
)
from deepmd.jax.utils.neighbor_stat import (
NeighborStat,
)
from deepmd.utils.update_sel import (
BaseUpdateSel,
)
class UpdateSel(BaseUpdateSel):
    """JAX backend implementation of :class:`BaseUpdateSel`.

    Its only customization is to supply the JAX :class:`NeighborStat`
    class through the ``neighbor_stat`` property.
    """

    @property
    def neighbor_stat(self) -> type[NeighborStat]:
        """Neighbor-statistics class for the JAX backend.

        Presumably consumed by :class:`BaseUpdateSel` to compute neighbor
        statistics -- confirm against the base class.

        Returns
        -------
        type[NeighborStat]
            The JAX ``NeighborStat`` class itself (not an instance).
        """
        return NeighborStat
def _get_update_sel_descriptors() -> tuple[type[Any], ...]:
    """Collect dpmodel descriptor classes that expose ``_update_sel_cls``.

    Returns
    -------
    tuple[type[Any], ...]
        Every registered descriptor plugin class that has a
        ``_update_sel_cls`` attribute (own or inherited), each listed once,
        in the order the plugin registry yields them.
    """
    # Imported only for its side effect (hence the F401 suppression);
    # presumably this registers the dpmodel descriptor plugins with
    # BaseDescriptor -- verify against deepmd.dpmodel.descriptor.__init__.
    import deepmd.dpmodel.descriptor as _dpmodel_descriptor  # noqa: F401
    from deepmd.dpmodel.descriptor.base_descriptor import (
        BaseDescriptor,
    )

    # Several plugin names may map to the same class, so deduplicate.
    # dict.fromkeys keeps first-seen order, whereas a set's iteration order
    # depends on object hashes and can differ from run to run.
    return tuple(
        dict.fromkeys(
            descriptor_cls
            for descriptor_cls in BaseDescriptor.get_plugins().values()
            if hasattr(descriptor_cls, "_update_sel_cls")
        )
    )
@contextmanager
def use_jax_update_sel() -> Iterator[None]:
    """Use JAX neighbor statistics in dpmodel descriptor update_sel methods.

    While the context is active, every class returned by
    ``_get_update_sel_descriptors`` has its ``_update_sel_cls`` attribute
    set to the JAX :class:`UpdateSel`. On exit, whether normal or via an
    exception, each class is restored in one of two ways:

    - a value the class defined in its own ``__dict__`` is put back;
    - an override placed on a class that previously only inherited the
      attribute is deleted, so inheritance takes effect again.

    Yields
    ------
    None
    """
    # Local sentinel distinguishing "not defined on this class itself" from
    # any real value (including None). Kept local so the function does not
    # depend on a module-level constant being defined elsewhere.
    missing = object()
    descriptor_classes = _get_update_sel_descriptors()
    # Snapshot only each class's *own* attribute, not an inherited one, so
    # restoring never pins a base-class value onto a subclass.
    saved_update_sel = {
        descriptor_cls: descriptor_cls.__dict__.get("_update_sel_cls", missing)
        for descriptor_cls in descriptor_classes
    }
    try:
        for descriptor_cls in descriptor_classes:
            descriptor_cls._update_sel_cls = UpdateSel
        yield
    finally:
        for descriptor_cls, update_sel_cls in saved_update_sel.items():
            if update_sel_cls is not missing:
                descriptor_cls._update_sel_cls = update_sel_cls
            elif "_update_sel_cls" in descriptor_cls.__dict__:
                # Remove only an override we actually installed. If patching
                # aborted before reaching this class, an unconditional ``del``
                # would raise AttributeError and mask the original exception.
                del descriptor_cls._update_sel_cls