# Source code for plumbum.machines.env

from __future__ import annotations

# Both names listed here are modules imported further down in this file.
# NOTE(review): presumably the PEP 810 ``__lazy_modules__`` opt-in that asks
# the interpreter to import these modules lazily — confirm against the
# Python versions this package supports.
__lazy_modules__ = {"contextlib", "copy"}

import copy
import os
import typing
from contextlib import contextmanager

if typing.TYPE_CHECKING:
    from collections.abc import (
        Generator,
        ItemsView,
        Iterable,
        Iterator,
        KeysView,
        MutableMapping,
        ValuesView,
    )
    from typing import Any, Callable, SupportsIndex

    from plumbum.path.base import Path

# Type variable for the ``Path`` subclass that the generic classes in this
# module are parameterised with.
AnyPath = typing.TypeVar("AnyPath", bound="Path")
# Covariant variant of ``AnyPath``; used as the return type of ``_PathFactory``.
AnyPath_co = typing.TypeVar("AnyPath_co", bound="Path", covariant=True)


class EnvPathList(list[AnyPath], typing.Generic[AnyPath]):
    """A list of path objects that accepts plain strings.

    The string-taking methods below convert their argument with the path
    factory supplied at construction time before delegating to ``list``.
    """

    __slots__ = ("__weakref__", "_path_factory", "_pathsep")

    def __init__(self, path_factory: Callable[[str], AnyPath], pathsep: str) -> None:
        """Create an empty list.

        :param path_factory: Callable turning a string into a path object
        :param pathsep: Separator used by ``update()`` and ``join()``
        """
        super().__init__()
        self._path_factory = path_factory
        self._pathsep: str = pathsep

    def append(self, path: str) -> None:
        """Convert *path* with the path factory and append it."""
        list.append(self, self._path_factory(path))

    def extend(self, paths: Iterable[str]) -> None:
        """Convert every string in *paths* and append the results in order."""
        list.extend(self, (self._path_factory(p) for p in paths))

    def insert(self, index: SupportsIndex, path: str) -> None:
        """Convert *path* and insert it before position *index*."""
        list.insert(self, index, self._path_factory(path))

    def index(self, path: str, *bounds: SupportsIndex) -> int:  # type: ignore[override]
        """Return the position of the first entry equal to the converted *path*.

        :param path: Path string; converted with the path factory first
        :param bounds: Optional ``start`` and ``stop`` positions, forwarded
                       unchanged to ``list.index``
        :raises ValueError: if no matching entry is found
        """
        return list.index(self, self._path_factory(path), *bounds)

    def __contains__(self, path: object) -> bool:
        """Tell whether the converted *path* equals one of the entries."""
        return list.__contains__(self, self._path_factory(path))  # type: ignore[arg-type]

    def remove(self, path: str) -> None:
        """Remove the first entry equal to the converted *path*.

        :raises ValueError: if no matching entry is found
        """
        list.remove(self, self._path_factory(path))

    def update(self, text: str) -> None:
        """Replace the whole content with the entries parsed from *text*.

        *text* is split on the path separator and every piece is converted
        with the path factory. The list object itself is kept; only its
        content is swapped via slice assignment.
        """
        self[:] = [self._path_factory(p) for p in text.split(self._pathsep)]

    def join(self) -> str:
        """Return the entries as strings, joined with the path separator."""
        return self._pathsep.join(str(p) for p in self)
class _PathFactory(typing.Protocol[AnyPath_co]):
    """Structural type of a callable that takes any number of strings and
    returns a path object."""

    def __call__(self, *args: str) -> AnyPath_co: ...
class BaseEnv(typing.Generic[AnyPath]):
    """The base class of LocalEnv and RemoteEnv"""

    __slots__ = ("__weakref__", "_curr", "_path", "_path_factory")
    CASE_SENSITIVE = True

    def __init__(
        self,
        path_factory: _PathFactory[AnyPath],
        pathsep: str,
        *,
        _curr: MutableMapping[str, str],
    ) -> None:
        """Wrap the mapping ``_curr`` as the current environment.

        :param path_factory: Callable building path objects from strings
        :param pathsep: Separator between the entries of ``PATH``
        :param _curr: Mapping that holds the environment variables
        """
        self._curr = _curr
        self._path_factory = path_factory
        self._path = EnvPathList[AnyPath](path_factory, pathsep)
        self._update_path()

    def _update_path(self) -> None:
        """Re-parse the ``PATH`` variable (or ``""`` if unset) into ``self._path``."""
        self._path.update(self.get("PATH", ""))

    @contextmanager
    def __call__(self, *args: Any, **kwargs: Any) -> Generator[None, None, None]:
        """A context manager that can be used for temporary modifications of the
        environment. Any time you enter the context, a copy of the old environment
        is stored, and then restored, when the context exits.

        :param args: Any positional arguments for ``update()``
        :param kwargs: Any keyword arguments for ``update()``
        """
        prev = copy.copy(self._curr)
        try:
            # update() runs inside the ``try`` so that a call which raises
            # after changing part of the mapping is rolled back as well.
            self.update(*args, **kwargs)
            yield
        finally:
            # The saved copy becomes the current mapping (rebinding, not an
            # in-place restore), then the path list is re-synchronised.
            self._curr = prev
            self._update_path()

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Returns an iterator over the items ``(key, value)`` of current environment
        (like dict.items)"""
        return iter(self._curr.items())

    # Environments are mutable, so instances are explicitly unhashable.
    __hash__ = None  # type: ignore[assignment]

    def __len__(self) -> int:
        """Returns the number of elements of the current environment"""
        return len(self._curr)

    def __contains__(self, name: str) -> bool:
        """Tests whether an environment variable exists in the current environment"""
        return (name if self.CASE_SENSITIVE else name.upper()) in self._curr

    def __getitem__(self, name: str) -> str:
        """Returns the value of the given environment variable from current environment,
        raising a ``KeyError`` if it does not exist"""
        return self._curr[name if self.CASE_SENSITIVE else name.upper()]

    def keys(self) -> KeysView[str]:
        """Returns the keys of the current environment (like dict.keys)"""
        return self._curr.keys()

    def items(self) -> ItemsView[str, str]:
        """Returns the items of the current environment (like dict.items)"""
        return self._curr.items()

    def values(self) -> ValuesView[str]:
        """Returns the values of the current environment (like dict.values)"""
        return self._curr.values()

    @typing.overload
    def get(self, name: str, default: None = ...) -> str | None: ...

    @typing.overload
    def get(self, name: str, default: str) -> str: ...

    def get(self, name: str, default: str | None = None) -> str | None:
        """Returns the value of the given environment variable from the current
        environment, or ``default`` if it does not exist (like dict.get)"""
        return self._curr.get((name if self.CASE_SENSITIVE else name.upper()), default)

    def __delitem__(self, name: str) -> None:
        """Deletes an environment variable from the current environment"""
        name = name if self.CASE_SENSITIVE else name.upper()
        del self._curr[name]
        if name == "PATH":
            self._update_path()

    def __setitem__(self, name: str, value: str) -> None:
        """Sets/replaces an environment variable's value in the current environment"""
        name = name if self.CASE_SENSITIVE else name.upper()
        self._curr[name] = value
        if name == "PATH":
            self._update_path()

    def pop(self, name: str, *default: str) -> str | None:
        """Pops an element from the current environment (like dict.pop)"""
        name = name if self.CASE_SENSITIVE else name.upper()
        res = self._curr.pop(name, *default)
        if name == "PATH":
            self._update_path()
        return res

    def clear(self) -> None:
        """Clears the current environment (like dict.clear)"""
        self._curr.clear()
        self._update_path()

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Updates the current environment (like dict.update)"""
        self._curr.update(*args, **kwargs)
        if not self.CASE_SENSITIVE:
            # Iterate over a snapshot: the loop writes into the same mapping.
            for k, v in list(self._curr.items()):
                self._curr[k.upper()] = v
        self._update_path()

    def getdict(self) -> dict[str, str]:
        """Returns the environment as a real dictionary

        The joined content of ``self.path`` is written back to ``PATH`` first,
        so changes made through the path list are included.
        """
        self._curr["PATH"] = self.path.join()
        return {k: str(v) for k, v in self._curr.items()}

    @property
    def path(self) -> EnvPathList[AnyPath]:
        """The system's ``PATH`` (as an easy-to-manipulate list)"""
        return self._path

    @property
    def home(self) -> Path | None:
        """Get or set the home path

        Reading checks ``HOME``, then ``USERPROFILE``, then ``HOMEPATH``
        (combined with ``HOMEDRIVE``) and returns ``None`` if none is set.
        Writing stores into the first of ``HOME``, ``USERPROFILE`` and
        ``HOMEPATH`` that exists, falling back to ``HOME``.
        """
        if "HOME" in self:
            return self._path_factory(self["HOME"])
        if "USERPROFILE" in self:  # pragma: no cover
            return self._path_factory(self["USERPROFILE"])
        if "HOMEPATH" in self:  # pragma: no cover
            return self._path_factory(self.get("HOMEDRIVE", ""), self["HOMEPATH"])
        return None

    @home.setter
    def home(self, p: Path) -> None:
        if "HOME" in self:
            self["HOME"] = str(p)
        elif "USERPROFILE" in self:  # pragma: no cover
            self["USERPROFILE"] = str(p)
        elif "HOMEPATH" in self:  # pragma: no cover
            self["HOMEPATH"] = str(p)
        else:  # pragma: no cover
            self["HOME"] = str(p)

    @property
    def user(self) -> str | None:
        """Return the user name, or ``None`` if it is not set"""
        # adapted from getpass.getuser()
        for name in ("LOGNAME", "USER", "LNAME", "USERNAME"):  # pragma: no branch
            if name in self:
                return self[name]
        try:
            # POSIX only
            import pwd
        except ImportError:
            return None
        return pwd.getpwuid(os.getuid())[0]  # @UndefinedVariable
# Public names of this module.
__all__ = ["BaseEnv", "EnvPathList"]


def __dir__() -> list[str]:
    """Return a fresh list holding the names declared in ``__all__``."""
    return [*__all__]