ok

Mini Shell

Direktori : /opt/imunify360/venv/lib64/python3.11/site-packages/defence360agent/files/
Upload File :
Current File : //opt/imunify360/venv/lib64/python3.11/site-packages/defence360agent/files/__init__.py

"""Utilities for managing local file storage synchronised with a remote
server.

Files are divided into types: signatures, ModSecurity bundles, IP
whitelists, etc. Each type is represented by an Index instance.

Each Index has a local subdirectory and a description (description.json)
holding its files' metadata, which is used to decide whether an update is
necessary.
"""

import asyncio
import datetime as DT
import hashlib
import http.client
import io
import json
import math
import os
import pathlib
import random
import shutil
import socket
import time
import zipfile
import urllib.error
import urllib.request
from collections import defaultdict, namedtuple
from contextlib import ExitStack, suppress, contextmanager
from email.utils import formatdate, parsedate_to_datetime
from gzip import GzipFile
from itertools import chain
from logging import getLogger
from typing import (
    Any,
    BinaryIO,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)
from urllib.parse import urlparse


from defence360agent.contracts import config
from defence360agent.contracts.license import LicenseCLN
from defence360agent.utils import file_hash, retry_on, run_with_umask
from defence360agent.utils.common import rate_limit, HOUR
from defence360agent.utils.threads import to_thread
from .hooks import default_hook

logger = getLogger(__name__)

# static file types
# NOTE(review): presumably registered elsewhere via Index.add_type(); the
# registration is not visible in this file section.
EULA = "eula"
SIGS = "sigs"  # malware signatures
REALTIME_AV_CONF = "realtime-av-conf"

# local root for all file types; Index.files_path() joins the per-type
# relative path onto it
FILES_DIR = pathlib.Path("/var/imunify360/files")
# remote root; per-type description.json and all.zip URLs are built as
# BASE_URL + <relative path> + "/<name>"
BASE_URL = "https://files.imunify360.com/static/"

# chunk size for network and file operations, in bytes
_BUFSIZE = 32 * 1024

# max_tries passed to retry_on() for every download / HEAD coroutine
_MAX_TRIES_FOR_DOWNLOAD = 10
# backoff unit in seconds: after failed try i, _log_failed_update() sleeps
# randrange(1 << i) units, i.e. strictly less than the bounds listed below
_TIMEOUT_MULTIPLICATOR = 0.025
"""
>>> _MAX_TRIES_FOR_DOWNLOAD = 10
>>> _TIMEOUT_MULTIPLICATOR = 0.025
>>> [(1 << i) * _TIMEOUT_MULTIPLICATOR for i in range(1, _MAX_TRIES_FOR_DOWNLOAD)]  # noqa
[0.05, 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8]
"""

#: sentinel: mtime for a missing/never modified file
#: (default returned by Index._descriptionfile_mtime())
_NEVER = -math.inf
# Approximation of a JSON value: recursive type aliases were not supported,
# so nested containers are typed as Any.
# https://github.com/python/typing/issues/182
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]


class IntegrityError(RuntimeError):
    """Local files disagree with their description.json.

    Raised when the description file cannot be read, or when files listed
    in it are missing or their md5 sums differ from the recorded ones.
    """


class UpdateError(RuntimeError):
    """A files update attempt failed for a reason other than local integrity.

    Typical causes:

    * the server answers with an error HTTP status;
    * downloaded content does not match the hash in description.json;
    * urllib / http.client errors, time outs, connection resets;
    * JSON decoding errors;
    * errors while writing to disk.
    """


async def _log_failed_update(exc, i):
    """``retry_on`` error callback: log failure *exc* of try *i*, then back off.

    The pause is a random whole number of ``_TIMEOUT_MULTIPLICATOR`` seconds
    in ``[0, 2**i - 1]`` (exponential backoff with jitter).
    """
    logger.warning(f"Files update failed with error: {exc}, try: {i}")
    backoff_units = random.randrange(1 << i)
    await asyncio.sleep(backoff_units * _TIMEOUT_MULTIPLICATOR)


def _open_with_mode(path: os.PathLike, mode: int) -> BinaryIO:
    """Open *path* for binary writing (created or truncated) and return it.

    A newly created file gets exactly the permission bits *mode*: the umask
    is cleared while the file descriptor is opened. An already existing
    file keeps its permissions.
    """
    open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    with run_with_umask(0):
        descriptor = os.open(path, open_flags, mode)
    return open(descriptor, "wb")


def _fetch_json_sync(url, timeout) -> JSONType:
    """Blocking download of *url*, decoded as JSON.

    The body is decoded using the charset from the Content-Type header,
    falling back to UTF-8.
    """
    with _fetch_url(url, timeout=timeout) as fetched:
        body = fetched["file"]
        charset = fetched["headers"].get_content_charset("utf-8")
        return json.load(io.TextIOWrapper(body, encoding=charset))


@retry_on(
    UpdateError, on_error=_log_failed_update, max_tries=_MAX_TRIES_FOR_DOWNLOAD
)
async def _fetch_json(url: str, timeout) -> JSONType:
    """Download *url* in a worker thread and return the decoded JSON.

    Every failure below is reported as UpdateError (so ``retry_on`` retries
    it):

    * Unicode or JSON decoding errors;
    * request time outs and connection resets;
    * EOFError while reading the body (e.g. a truncated gzip stream);
    * HTTP errors reported by urllib / http.client;
    * any other OSError.
    """
    running_loop = asyncio.get_running_loop()
    try:
        decoded = await running_loop.run_in_executor(
            None, _fetch_json_sync, url, timeout
        )
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        raise UpdateError(f"json decode error [{e}] for url {url}")
    except socket.timeout:
        raise UpdateError(f"request to {url} timed out")
    except ConnectionResetError:
        raise UpdateError(f"request to {url} reset")
    except EOFError as e:
        raise UpdateError(
            f"eof error while updating files, url: {url}, err: {e}"
        )
    except (http.client.HTTPException, urllib.error.URLError) as e:
        raise UpdateError(
            f"urllib/http error while updating files, url: {url}, err: {e}"
        )
    except OSError as e:
        raise UpdateError(f"Can't fetch {url}, reason: {e}")
    return decoded


def _perform_http_head_sync(
    url: str, timeout: float, *, headers: Optional[Dict[str, str]] = None
):
    """Perform a blocking HEAD request to *url* with *timeout*.

    The ``Imunify-Server-Id`` header is always sent; entries of *headers*
    (default: none) are added on top of it and may override it.

    Return ``(status_code, response_headers)``. Error statuses are raised by
    ``urlopen`` as ``urllib.error.HTTPError`` (the caller relies on this to
    detect 304 Not Modified).
    """
    # ``None`` instead of a shared mutable ``{}`` default
    req = urllib.request.Request(
        url,
        headers={
            "Imunify-Server-Id": LicenseCLN.get_server_id() or "",
            **(headers or {}),
        },
        method="HEAD",
    )
    with urllib.request.urlopen(req, timeout=timeout) as r:
        return r.code, r.headers


@retry_on(
    UpdateError, on_error=_log_failed_update, max_tries=_MAX_TRIES_FOR_DOWNLOAD
)
async def _need_to_download(
    url: str, current_mtime: float, timeout: float
) -> bool:
    """Check whether the remote file at *url* must be downloaded.

    Without a local copy (*current_mtime* is ``-inf``) return True right
    away. Otherwise send a conditional HEAD request with
    ``If-Modified-Since: <current_mtime>``: a 304 answer means the local
    copy is current (return False), a 200 answer means download it
    (return True).

    :raise UpdateError: on time outs, connection resets, HTTP errors other
        than 304, unexpected status codes and other OSErrors.
    """
    # Compare by value rather than identity: any -inf means "never updated",
    # and formatdate() below cannot format an infinite timestamp.
    if current_mtime == _NEVER:
        return True  # need to download it
    formatted_mtime = formatdate(current_mtime, usegmt=True)
    try:
        code, headers = await to_thread(
            _perform_http_head_sync,
            url,
            timeout,
            headers={"If-Modified-Since": formatted_mtime},
        )
    except socket.timeout:
        raise UpdateError("request to {} timed out".format(url))
    except ConnectionResetError:
        raise UpdateError("request to {} reset".format(url))
    except (http.client.HTTPException, urllib.error.URLError) as e:
        # urlopen raises HTTPError (a URLError carrying .code) for 304
        if getattr(e, "code", None) == 304:  # NOSONAR file not modified
            return False  # no need to re-download
        raise UpdateError(
            "urllib/http error while updating files, url: {}, err: {}".format(
                url, e
            )
        )
    except OSError as e:
        # same treatment as _fetch_json / _fetch_and_save: retryable
        raise UpdateError(f"Can't check {url} for updates, reason: {e}")

    if code != 200:
        raise UpdateError(
            f"Unexpected http code {code!r} for {url}"
        )  # pragma: no cover
    # Diagnostics only (best effort): a server that answers 200 although
    # the file is not newer ignores If-Modified-Since.
    with suppress(Exception):
        last_mtime = parsedate_to_datetime(
            headers["Last-Modified"]
        ).timestamp()
        if last_mtime <= current_mtime:  # file on the server NOT newer
            logger.warning(
                "Got code %r, but last modification date %s is earlier"
                " than or equal to the date provided in the"
                " If-Modified-Since header, the origin server SHOULD"
                " generate a 304 (Not Modified) response [rfc7232]."
                " Here's curl cmd:\ncurl -s -I -w '%%{http_code}' -H"
                " 'If-Modified-Since: %s' '%s'",
                code,
                headers["Last-Modified"],
                formatted_mtime,
                url,
            )

    return True  # file has been modified since current mtime, re-download


@contextmanager
def _fetch_url(url: str, *, timeout: float, compress=True):
    """Open *url* and yield its body as a readable binary stream.

    Yields a dict with:

    * ``"file"``: the response body, transparently gunzipped whenever the
      server answered with ``Content-Encoding: gzip`` (regardless of
      *compress*);
    * ``"headers"``: the response headers.

    When *compress* is true the request advertises gzip support (identity
    encoding stays acceptable). A *timeout* of None means urlopen's default
    timeout. Both streams are closed on exit, the gzip wrapper first.
    """
    request_headers = {"Imunify-Server-Id": LicenseCLN.get_server_id() or ""}
    if compress:
        request_headers["Accept-Encoding"] = "gzip"
    request = urllib.request.Request(url, headers=request_headers)
    urlopen_kwargs = {} if timeout is None else {"timeout": timeout}

    with ExitStack() as stack:
        response = stack.enter_context(
            urllib.request.urlopen(request, **urlopen_kwargs)
        )
        response_headers = response.headers
        is_gzipped = response_headers.get("Content-Encoding") == "gzip"
        is_zip = response_headers.get("Content-Type") == "application/zip"
        if compress and not (is_gzipped or is_zip):
            logger.info(
                "Requested gzip but got Content-Encoding=%r."
                " Read response as is [identity]. Headers: %s,"
                " as curl cmd:\ncurl -Is -H 'Accept-Encoding: gzip' '%s'",
                response_headers.get("Content-Encoding"),
                response_headers.items(),
                url,
            )
        if is_gzipped:
            body = stack.enter_context(GzipFile(fileobj=response))
        else:
            body = response
        yield {"file": body, "headers": response_headers}


def _fetch_n_md5sum_url(
    url, dest_file: BinaryIO, timeout, *, compress, md5sum
):
    """
    Fetch *url* to *dest_file* and return the md5 hex digest of the
    (decompressed) content written.

    :param dest_file: writable binary file; content is appended at its
        current position.
    :param compress: forwarded to _fetch_url (request gzip transfer).
    :param md5sum: expected hex digest, or None to skip the comparison.
    :raise urllib.error.ContentTooShortError: if the number of bytes written
        differs from Content-Length (checked only for responses that are
        not gzip-encoded).
    :raise UpdateError: if Content-Length is not an integer, or the digest
        does not match *md5sum*.
    """
    # md5 only guards integrity here; without usedforsecurity=False it may
    # be rejected on FIPS-enabled OpenSSL builds.
    md5 = hashlib.md5(usedforsecurity=False)
    initial_file_offset = dest_file.tell()
    with _fetch_url(url, timeout=timeout, compress=compress) as response:
        source = response["file"]
        headers = response["headers"]
        while chunk := source.read(_BUFSIZE):  # NOSONAR
            md5.update(chunk)
            dest_file.write(chunk)
    if headers.get("Content-Encoding") != "gzip":
        # for gzip, Content-Length is the compressed size
        # -> no point in comparing with the uncompressed result.
        # Content-Length may be absent, e.g. with Transfer-Encoding: chunked
        content_length_header = headers.get("Content-Length", None)
        if content_length_header is not None:
            try:
                expected_file_length = int(content_length_header)
            except ValueError:
                # http.client tolerates such a header; don't leak a bare
                # ValueError that callers neither convert nor retry
                raise UpdateError(
                    f"invalid Content-Length {content_length_header!r}"
                    f" in response from {url}"
                ) from None
            file_length = dest_file.tell() - initial_file_offset
            # make sure the file has been downloaded correctly
            if expected_file_length != file_length:
                raise urllib.error.ContentTooShortError(
                    message="{got} bytes read, {diff} more expected".format(
                        got=file_length,
                        diff=expected_file_length - file_length,
                    ),
                    content=None,
                )
    got_md5sum = md5.hexdigest()
    if md5sum is not None and got_md5sum != md5sum:
        raise UpdateError(
            f"content fetched from {url} does not match hash:"
            f" expected={md5sum}, got={got_md5sum}"
        )
    return got_md5sum


@retry_on(
    UpdateError, on_error=_log_failed_update, max_tries=_MAX_TRIES_FOR_DOWNLOAD
)
async def _fetch_and_save(
    url: str,
    dest_path: os.PathLike,
    timeout,
    *,
    dest_mode: int,
    compress=True,
    md5sum=None,
) -> str:
    """Download *url* into *dest_path* and return the content's md5 digest.

    *dest_path* is created (or truncated) with permission bits *dest_mode*.
    The transfer runs in a worker thread. When *md5sum* is given the
    content must match it.

    Network, HTTP, EOF and file-system errors (opening, writing and closing
    *dest_path* included) are reported as UpdateError, so ``retry_on``
    retries them.
    """
    try:
        with _open_with_mode(dest_path, dest_mode) as sink:
            checksum = await to_thread(
                _fetch_n_md5sum_url,
                url,
                sink,
                timeout,
                compress=compress,
                md5sum=md5sum,
            )
    except socket.timeout:
        raise UpdateError(f"request to {url} timed out")
    except ConnectionResetError:
        raise UpdateError(f"request to {url} reset")
    except EOFError as e:
        raise UpdateError(
            f"eof error while updating files, url: {url}, err: {e}"
        )
    except (http.client.HTTPException, urllib.error.URLError) as e:
        raise UpdateError(
            f"urllib/http error while updating files, url: {url}, err: {e}"
        )
    except OSError as e:
        raise UpdateError(f"Can't fetch {url} to {dest_path}, reason: {e}")
    return checksum


_Item = namedtuple("_Item", ["url", "md5sum"])


def _items(data: Any) -> Set[_Item]:
    """Return a set of _Item for easy manipulation."""
    return {_Item(item["url"], item["md5sum"]) for item in data["items"]}


def check_mode_dirs(dirname, dir_perm, file_perm):
    """Check and change file/dir modes recursively.

    Starting at *dirname* (inclusive), set the permission bits of every
    directory to *dir_perm* and of every other entry to *file_perm*.
    Symlinks are never changed. Entries that disappear while the tree is
    walked (or a missing *dirname*) are skipped; permission errors are
    logged and skipped.
    """
    import stat

    def _os_chmod(file_dir_path, permission):
        try:
            # a single lstat() serves both the mode and the symlink test
            st_mode = os.lstat(file_dir_path).st_mode
            current_mode = st_mode & 0o777
            if current_mode != permission and not stat.S_ISLNK(st_mode):
                logger.warning(
                    "Fixing wrong permission to file/dir"
                    " %s [%s] expected [%s] (not symlink)",
                    file_dir_path,
                    oct(current_mode),
                    oct(permission),
                )
                os.chmod(file_dir_path, permission)
        except FileNotFoundError:
            # removed after os.walk() listed it: nothing left to fix
            pass
        except PermissionError:
            logger.error(
                "Failed to change permission to file %s", file_dir_path
            )

    _os_chmod(dirname, dir_perm)
    for path, dirs, files in os.walk(dirname):
        for directory in dirs:
            _os_chmod(os.path.join(path, directory), dir_perm)
        for name in files:
            _os_chmod(os.path.join(path, name), file_perm)


def _fix_directory_structure(
    description_path: pathlib.Path, files_path: pathlib.Path = FILES_DIR
) -> None:
    """
    Try to fix the structure of /var/imunify360/files/ when
    NotADirectoryError happens.

    It indicates that some component of the path is a regular file:
    /var/imunify360/files/sigs <- is a file
    => open("/var/imunify360/files/sigs/v1/description.json") will fail.

    Walk from the directory of *description_path* up to (not including)
    *files_path*. Delete the first regular file found there and recreate
    the description's directory.

    :raise ValueError: if *description_path* is not inside *files_path*.
    """
    # explicit check instead of ``assert``: without it (e.g. under -O) the
    # upward walk had no stop condition
    if files_path not in description_path.parents:
        raise ValueError(
            f"{description_path} is not located inside {files_path}"
        )
    description_dir = description_path.parent
    for candidate in description_path.parents:
        if candidate == files_path:
            break
        if candidate.is_file():
            candidate.unlink(missing_ok=True)
            description_dir.mkdir(parents=True, exist_ok=True)
            break


class Index:
    # one lock is shared via Index and that allows
    # more than one instance of Index to co-exist.
    # All attributes below are class-level, hence shared by every Index
    # instance and subclass (including the FileGroup classes built by
    # _make_file_group()).
    # One asyncio.Lock per key (file type), created on first lookup; see
    # locked().
    _lock = defaultdict(asyncio.Lock)  # type: Dict[Any, asyncio.Lock]
    # NOTE(review): not used in the visible part of this file; presumably
    # maps a type to hook callables (cf. the default_hook import) - confirm.
    _HOOKS = defaultdict(set)  # type: Dict[str, Set[Any]]
    # type -> relative path under FILES_DIR locally and BASE_URL remotely
    _PATHS = {}  # type: Dict[str, str]
    # type -> {"dir": <mode>, "file": <mode>}, filled by add_type()
    _PERMS = {}  # type: Dict[str, Dict[str, int]]
    # every registered type; __init__ rejects any other value
    _TYPES = set()  # type: Set[str]
    # types checked by essential_files_exist()
    _ESSENTIAL_TYPES = set()  # type: Set[str]
    # type -> whether all.zip may be used to update it (see add_type())
    _ALL_ZIP_SUPPORT = {}  # type: Dict[str, bool]
    # NOTE(review): not referenced in the visible part of this file
    _URL_PATH_PREFIX = "/static"
    # logger.error wrapped by rate_limit(period=4 * HOUR); presumably it
    # suppresses repeated calls within that period - confirm in utils.common
    _throttled_log_error = rate_limit(period=4 * HOUR)(logger.error)

    def __init__(self, type_, integrity_check=True):
        """Load the local description.json of *type_* files.

        :param type_: a file type registered with add_type().
        :param bool integrity_check: check if last update
            did not break anything (by interrupting it in the middle or
            another programmatic error)
        :raise ValueError: if *type_* is not registered.
        :raise IntegrityError: if *integrity_check* is true and either the
            description file cannot be read or some listed files are
            missing or corrupted.
        """
        if type_ not in self._TYPES:
            raise ValueError(
                f"Trying to initiate unregistered file type {type_}. Allowed"
                f" types {self._TYPES}"
            )
        self.type = type_
        self._is_blank = False
        self._json = {"items": []}
        path = self._descriptionfile_path()
        try:
            # the description is written as UTF-8 encoded JSON
            with open(path, encoding="utf-8") as f:
                self._json = json.load(f)
        except (
            NotADirectoryError,
            FileNotFoundError,
            UnicodeDecodeError,
            json.JSONDecodeError,
        ) as e:
            if isinstance(e, NotADirectoryError):
                # a parent path component is a regular file: repair the
                # layout so that the next update can write the description
                Index._throttled_log_error(
                    "Path %s has a file in parents", path
                )
                _fix_directory_structure(pathlib.Path(path), FILES_DIR)
            # In every case no description could be read. Before this fix a
            # NotADirectoryError left the index non-blank with zero items,
            # which silently passed the integrity check.
            if integrity_check:
                raise IntegrityError(
                    "cannot read description file {}".format(path)
                ) from e
            self._is_blank = True
        if integrity_check:
            bad_files = self._corrupted_files()
            if bad_files:
                raise IntegrityError(
                    "some files are missing or corrupted: {}".format(
                        ", ".join(bad_files)
                    )
                )
        if not self._is_blank:
            self.check_mode_dirs()

    def __eq__(self, other):
        """Indexes are equal when class, type, blank flag and JSON all match."""
        if self.__class__ != other.__class__:
            return False
        if self.type != other.type:
            return False
        if self._is_blank != other._is_blank:
            return False
        return self._json == other._json

    def __repr__(self):  # pragma: no cover
        template = "<{}(type_={}) is_blank={}, json={{<{} item(s)>}}>"
        return template.format(
            self.__class__.__name__,
            self.type,
            self._is_blank,
            len(self.items()),
        )

    def validate(self, files_path: os.PathLike) -> None:
        """Check that *files_path* holds a consistent copy of this type's files.

        The directory must contain a readable description.json whose listed
        files are all present with matching md5 sums.

        :raises: IntegrityError
        """
        logger.info("Validating [%s]: %s", self.type, files_path)
        group_cls = self._make_file_group(files_path)
        # the integrity-checking constructor raises on any mismatch
        group_cls(self.type, integrity_check=True)

    def _make_file_group(self, files_path: os.PathLike):
        """Build a subclass of this index's class rooted at *files_path*.

        The returned class behaves like ``self.__class__``, except that
        ``files_path()`` returns *files_path* and accepts only this index's
        type.
        """
        owner = self
        base_cls = self.__class__

        class FileGroup(base_cls):
            @classmethod
            def files_path(cls, type_: str) -> str:
                """Overridden local base path; valid only for owner's type."""
                assert type_ == owner.type
                return os.fspath(files_path)

        return FileGroup

    def check_mode_dirs(self):
        """Fix modes under the parent of this type's directory in FILES_DIR.

        The parent is used, so sibling directories of the live path are
        covered as well. Modes come from the type's registration.
        """
        modes = Index._PERMS[self.type]
        type_dir = os.path.join(FILES_DIR, Index._PATHS[self.type])
        parent_dir = os.path.normpath(os.path.join(type_dir, os.pardir))
        check_mode_dirs(parent_dir, modes["dir"], modes["file"])

    @classmethod
    def add_type(
        cls,
        type_: str,
        relative_path: str,
        dir_perm: int,
        file_perm: int,
        *,
        all_zip: bool = False,
        essential: bool = True,
    ) -> None:
        """Add a type to known file types, or update its registration.

        * relative_path is a relative path to all files for that type.
        * dir_perm is permission mask used to create directories.
        * file_perm is permission mask used to create files.
        * all_zip is a flag which shows whether that type of files can
          be downloaded in all.zip archive. all.zip is expected to be on
          the server.
        * essential is whether the agent can start if there are errors
          updating that type.

        Calling it again for the same type replaces every setting.
        """
        cls._TYPES.add(type_)
        if essential:
            cls._ESSENTIAL_TYPES.add(type_)
        else:
            # like every other setting, a re-registration overrides an
            # earlier essential=True
            cls._ESSENTIAL_TYPES.discard(type_)
        cls._PATHS[type_] = relative_path
        cls._PERMS[type_] = {"dir": dir_perm, "file": file_perm}
        cls._ALL_ZIP_SUPPORT[type_] = all_zip

    @classmethod
    async def essential_files_exist(cls) -> bool:
        """Tell whether every essential type has a readable description file.

        The description file stands in for the type's files, which are not
        verified here (no integrity check), so they may still be corrupted.
        """
        for essential_type in cls._ESSENTIAL_TYPES:
            if Index(essential_type, integrity_check=False)._is_blank:
                return False
        return True

    @classmethod
    def types(cls) -> Set[str]:
        """Snapshot (a new set) of every registered file type."""
        return set(cls._TYPES)

    @classmethod
    def files_path(cls, type_: str) -> str:
        """Local base directory of *type_*: its relative path under FILES_DIR."""
        relative = cls._PATHS[type_]
        return os.path.join(FILES_DIR, relative)

    def _descriptionfile_path(self):
        """Local path of this index's description.json."""
        base_dir = self.files_path(self.type)
        return os.path.join(base_dir, "description.json")

    def _corrupted_files(self) -> Set[str]:
        """Return a set of file paths that are missing or corrupted.

        A listed file is bad when it cannot be hashed (missing, a directory,
        a file in its parent path, unreadable, ...) or when its md5 differs
        from the recorded one.
        """
        bad_files = set()
        for item in _items(self._json):
            path = self.localfilepath(item.url)
            try:
                actual = file_hash(path, hashlib.md5, _BUFSIZE)
            except OSError:
                # not only FileNotFoundError: IsADirectoryError,
                # NotADirectoryError, PermissionError etc. also mean the
                # file is unusable. Report it as bad so an update replaces
                # it instead of crashing the integrity check.
                bad_files.add(path)
                continue
            if actual != item.md5sum:
                bad_files.add(path)
        return bad_files

    @classmethod
    def locked(cls, type_):
        """
        Return the asyncio.Lock serializing work on *type_* files
        (created on first request), e.g.:

        >> async with Index.locked(WHITELISTS):
            ...
        """
        locks_by_type = cls._lock
        return locks_by_type[type_]

    def files(self) -> Iterable[str]:
        """Lazily yield the local path of every file listed in the index."""
        urls = (entry.url for entry in _items(self._json))
        return map(self.localfilepath, urls)

    def items(self):
        """The raw ``items`` list of the loaded description JSON."""
        description = self._json
        return description["items"]

    def _descriptionfile_mtime(self, default=_NEVER) -> float:
        """mtime of description.json, or *default* (``_NEVER``) if stat fails."""
        try:
            stat_result = os.stat(self._descriptionfile_path())
        except OSError:
            return default
        return stat_result.st_mtime

    def _is_outdated(self) -> bool:
        """True when description.json was last touched over PERIOD seconds ago.

        A missing description counts as outdated (its mtime is ``-inf``).
        """
        last_update = self._descriptionfile_mtime()
        if not last_update:
            return True  # pragma: no cover
        expires_at = last_update + config.FilesUpdate.PERIOD
        return time.time() > expires_at

    async def is_update_needed(self, timeout: float) -> bool:
        """Decide whether this index must be refreshed from the server.

        Cheap local checks come first (no description, corrupted files). A
        conditional HEAD request is sent only when the local copy is older
        than the update period.
        """
        if self._is_blank or self._corrupted_files():
            return True
        if not self._is_outdated():
            return False
        return await _need_to_download(
            self._descriptionfile_url(self.type),
            self._descriptionfile_mtime(),
            timeout,
        )

    def _makedirs(self, dirname, dir_mode, exist_ok=False):
        """Create *dirname*, and any missing parents, with mode *dir_mode*.

        os.makedirs() applies *mode* only to the leaf directory. Missing
        intermediate directories get 0o777 masked by the process umask, so
        running it under umask 0 made them world-writable. The umask is
        therefore derived from *dir_mode*: every directory created here
        ends up with *dir_mode*'s permission bits.

        :raise UpdateError: on any OSError (FileExistsError included when
            *exist_ok* is false).
        """
        # run_with_umask(mask) sets the process umask for the block (cf. its
        # use with 0 in _open_with_mode, where the given mode is applied as-is)
        derived_umask = 0o777 & ~dir_mode
        try:
            with run_with_umask(derived_umask):
                os.makedirs(dirname, mode=dir_mode, exist_ok=exist_ok)
        except OSError as e:
            raise UpdateError(str(e)) from e

    async def _update_files(
        self, files_path: pathlib.Path, to_update: Set[_Item], timeout
    ) -> None:
        """
        Download every item of *to_update* below *files_path*.

        Each item's md5 is checked during the download. Missing directories
        are created with the type's directory mode, files with its file
        mode.
        """
        group = self._make_file_group(files_path)(
            self.type, integrity_check=False
        )
        modes = group._PERMS[group.type]  # NOSONAR disable python:W0212
        dir_mode, file_mode = modes["dir"], modes["file"]
        for entry in to_update:
            target = group.localfilepath(entry.url)
            target_dir = os.path.dirname(target)
            if not os.path.isdir(target_dir):
                self._makedirs(target_dir, dir_mode, exist_ok=False)
            await _fetch_and_save(
                entry.url,
                target,
                timeout,
                dest_mode=file_mode,
                md5sum=entry.md5sum,
            )

    def _calculate_changes(
        self, remote_items: Set[_Item]
    ) -> Tuple[Set[_Item], Set[str]]:
        """Diff the local index (and on-disk state) against *remote_items*.

        Return ``(to_fetch, to_delete)``:

        * to_fetch: remote _Items that are new, changed, or whose local copy
          is missing or corrupted;
        * to_delete: local file paths that the remote no longer lists.
        """
        local_items = _items(self._json)
        local_paths = {self.localfilepath(entry.url) for entry in local_items}
        remote_paths = {
            self.localfilepath(entry.url) for entry in remote_items
        }
        to_delete = local_paths - remote_paths

        corrupted = self._corrupted_files()
        intact_items = set()
        for entry in local_items:
            if self.localfilepath(entry.url) not in corrupted:
                intact_items.add(entry)
        return remote_items - intact_items, to_delete

    @classmethod
    def _descriptionfile_url(cls, type_: str) -> str:
        """Remote URL of description.json for *type_*."""
        return f"{BASE_URL}{cls._PATHS[type_]}/description.json"

    @classmethod
    def _all_zip_url(cls, type_: str) -> str:
        """Remote URL of the all.zip archive for *type_*."""
        return f"{BASE_URL}{cls._PATHS[type_]}/all.zip"

    @staticmethod
    def _all_zip_cleanup(files_path, all_zip_localpath, remove_files=False):
        """Delete the downloaded archive and, optionally, the *files_path* tree.

        Failure to delete the archive is only logged; the tree is removed
        with errors ignored.
        """
        try:
            os.unlink(all_zip_localpath)
        except OSError as err:
            logger.warning(
                "failed to remove %s: %s", all_zip_localpath, str(err)
            )
        if not remove_files:
            return
        logger.info("Removing old path on all.zip update: %s", files_path)
        shutil.rmtree(files_path, ignore_errors=True)

    @staticmethod
    def _generate_new_path(live_path: pathlib.Path) -> pathlib.Path:
        """Generate new base local path for *live_path* files.

        The result is a sibling of *live_path* named
        ``<name>_<UTC timestamp>``. Being on the same filesystem partition
        as *live_path*, it can be renamed atomically.
        """
        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
        # datetime gives identical output for this format (it has no %z)
        now_utc = DT.datetime.now(DT.timezone.utc)
        new_suffix = now_utc.strftime("_%Y-%m-%dT%H%M%S.%fZ")
        return live_path.with_name(live_path.name + new_suffix)

    async def _run_update_all_zip(self, timeout) -> bool:
        """
        Update current type of files from the server's all.zip archive.

        The archive is downloaded next to the live directory and extracted
        into a fresh timestamped sibling directory. Permissions are then
        set and the result is validated against the description.json
        contained in the archive. Finally the live path is switched to the
        new directory. The previous directory (if live was a symlink) and
        the archive are removed afterwards. On any failure the new
        directory and the archive are removed and the live path is left
        alone.

        Return whether updated (always True when no exception is raised).

        :param timeout: passed to the archive download.
        :raise UpdateError: if OSError or http error or
                            integrity check error (got wrong data from
                            the server)
        """
        live_path = pathlib.Path(self.files_path(self.type))
        new_path = Index._generate_new_path(live_path)
        # sibling file "<new dir name>all.zip", outside the new directory
        archive_path = new_path.with_name(new_path.name + "all.zip")

        file_mode = self._PERMS[self.type]["file"]
        dir_mode = self._PERMS[self.type]["dir"]
        all_zip_url = self._all_zip_url(self.type)
        with ExitStack() as rollback_stack:
            # make new download dir
            self._makedirs(new_path, dir_mode, exist_ok=False)
            # on any exception below: delete the archive and the new dir
            rollback_stack.callback(
                Index._all_zip_cleanup,
                new_path,
                archive_path,
                remove_files=True,
            )

            # download the archive (raw: compress=False, no md5 check yet)
            # TODO: DEF-16354 check md5sum for all.zip
            _ = await _fetch_and_save(
                all_zip_url,
                archive_path,
                timeout,
                dest_mode=file_mode,
                compress=False,
            )

            # extract files to new dir with right permissions & verify
            try:
                with zipfile.ZipFile(archive_path, "r") as archive:
                    # NOTE: this also verifies crc-32 checksum for files
                    archive.extractall(new_path)

                # set mode
                for root, directories, filenames in os.walk(new_path):
                    for directory in directories:
                        os.chmod(os.path.join(root, directory), dir_mode)
                    for filename in filenames:
                        os.chmod(os.path.join(root, filename), file_mode)

                # verify against included description.json
                # (the archive must carry it at its root)
                self.validate(new_path)

                # create symlink to new dir, replace *live* with the symlink
                old_path = Index._replace_live_with_new_dir(
                    new_path, live_path
                )
            except (
                EOFError,
                IntegrityError,
                OSError,
                zipfile.BadZipfile,
                zipfile.LargeZipFile,
            ) as e:
                raise UpdateError(str(e)) from e

            # no exception, clear the rollback stack
            rollback_stack.pop_all()

        # cleanup: remove old dir & new all.zip
        # (old_path is None when live was a plain directory; that directory
        # was already removed by _replace_live_with_new_dir)
        Index._all_zip_cleanup(
            old_path, archive_path, remove_files=bool(old_path)
        )
        return True  # updated

    @staticmethod
    def _replace_live_with_new_dir(
        new_path: pathlib.Path, live_path: pathlib.Path
    ) -> Optional[pathlib.Path]:
        """Switch *live_path* to point at *new_path*.

        A symlink ``<new_path.name>live`` pointing at *new_path* is created
        next to it and renamed over *live_path*. This relies on rename()
        being atomic when both paths are on the same filesystem.

        If *live_path* is a real directory (not a symlink), the rename
        fails with IsADirectoryError. The directory is then moved aside to
        ``<new_path.name>live.live-moved`` and the rename is retried once;
        after success that moved directory is deleted here.

        On failure the rollback callbacks run in reverse order: the moved
        directory (if any) is put back, then the temporary symlink is
        removed.

        Return *old_path*: the resolved previous target of *live_path* if it
        was a symlink (for the caller to clean up), otherwise None.

        :raises: OSError
        """
        new_live_path = new_path.with_name(new_path.name + "live")
        moved_path = None
        with ExitStack() as rollback_stack:
            new_live_path.symlink_to(new_path, target_is_directory=True)
            rollback_stack.callback(new_live_path.unlink)
            # save the path to old dir for the cleanup
            old_path = (
                live_path.resolve(strict=False)
                if live_path.is_symlink()
                else None
            )
            # switch to the new version
            # NOTE: nothing until this point touched old version;
            #       the rename should be atomic
            #       (paths are on the same partition)
            # at most two attempts: the second follows moving a real
            # directory out of the way
            for last in range(2):  # pragma: no branch
                try:
                    new_live_path.rename(live_path)
                    break
                except IsADirectoryError:
                    if last:  # give up (keep old)
                        raise  # pragma: no cover

                    # live_path is a directory
                    # (old agent version or tests)
                    # move it so that the rename above could happen
                    if not live_path.is_symlink():  # pragma: no branch
                        # use unique to the current update name
                        moved_path = new_live_path.with_name(
                            new_live_path.name + ".live-moved"
                        )
                        logger.info(
                            "Moving %s [live] to %s,"
                            " to rename %s to it [live]",
                            live_path,
                            moved_path,
                            new_live_path,
                        )
                        live_path.replace(moved_path)
                        # if enabling new_live fails the 2nd time,
                        # try to move back, to restore old dir
                        rollback_stack.callback(moved_path.replace, live_path)

            # the old real directory is not reported via old_path (None),
            # so it is removed here
            if moved_path is not None:
                shutil.rmtree(moved_path, ignore_errors=True)

            # no exception, clear the rollback stack
            rollback_stack.pop_all()

        return old_path

    async def _run_update(self, timeout) -> bool:
        """
        Run update, return whether updated.

        The remote description.json is compared with the local state.
        Unchanged files are copied into a freshly created directory and
        changed files are downloaded there. The new directory is validated
        and then swapped in for the live path. If any step fails, the new
        directory is removed via the rollback stack.

        :param timeout: passed to ``_fetch_json`` and ``_update_files``
        :return: True if files were updated, False if nothing had changed
        :raise UpdateError: if OSError or http error or
                            integrity check error (got wrong data from
                            the server)
        """
        url = self._descriptionfile_url(self.type)
        as_json = await _fetch_json(url, timeout=timeout)
        to_update, to_remove = self._calculate_changes(_items(as_json))
        if not (to_update or to_remove):
            logger.info("updating %s: nothing to update.", self.type)
            self._touch()  # postpone the next try for FilesUpdate.PERIOD
            return False  # not updated

        # perform atomic update
        live_path = pathlib.Path(self.files_path(self.type))
        # note: it is ok if the symlink changes before .resolve() is called
        old_path = (
            live_path.resolve(strict=False) if live_path.is_symlink() else None
        )
        new_path = Index._generate_new_path(live_path)

        with ExitStack() as rollback_stack:
            # Every step that touches the file system is covered, so that an
            # OSError/IntegrityError from any of them is reported as
            # UpdateError (as documented), not just the final steps.
            try:
                # make new download dir
                self._makedirs(
                    new_path, self._PERMS[self.type]["dir"], exist_ok=False
                )
                rollback_stack.callback(
                    shutil.rmtree, new_path, ignore_errors=True
                )

                # copy all files from *old* dir to *new* dir except those
                # that need updating
                from_path = (
                    old_path if old_path and old_path.is_dir() else live_path
                )
                if from_path.is_dir():
                    ignored = to_remove.union(
                        self.localfilepath(item.url) for item in to_update
                    )
                    # The paths above are rooted at *live_path*, while
                    # copytree hands the ignore filter paths rooted at
                    # *from_path* (the resolved symlink target when *live*
                    # is a symlink). Add the rebased paths so they match.
                    ignored = ignored.union(
                        os.path.join(from_path, os.path.relpath(p, live_path))
                        for p in ignored
                    )
                    await Index._copytree(from_path, new_path, ignored)

                # download *to_update* files to *new_path*
                await self._update_files(new_path, to_update, timeout=timeout)

                # write description.json
                with _open_with_mode(
                    new_path / "description.json",
                    self._PERMS[self.type]["file"],
                ) as file:
                    file.write(json.dumps(as_json).encode())

                # verify against included description.json
                self.validate(new_path)

                # create symlink to new dir, replace *live* with the symlink
                old_path = self._replace_live_with_new_dir(new_path, live_path)
            except (IntegrityError, OSError) as e:
                raise UpdateError(str(e)) from e

            # no exception, clear the rollback stack
            rollback_stack.pop_all()

        # cleanup: remove old path on success
        if old_path and old_path.is_dir():
            logger.info(
                "Removing old path on file by file update: %s", old_path
            )
            shutil.rmtree(old_path, ignore_errors=True)

        return True  # updated

    @staticmethod
    async def _copytree(
        from_dir: os.PathLike, to_dir: os.PathLike, ignored_paths: Set[str]
    ) -> None:
        """Recursively copy *from_dir* into *to_dir*, skipping *ignored_paths*.

        Symlinks are copied as symlinks, and *to_dir* may already exist.
        The blocking ``shutil.copytree`` call is run through ``to_thread``.
        """

        def skip(directory, entries):
            """Select the *entries* of *directory* listed in *ignored_paths*."""
            # the joined paths are compared with str entries: no bytes here
            assert isinstance(os.fspath(directory), str)
            skipped = set()
            for entry in entries:
                if os.path.join(directory, entry) in ignored_paths:
                    skipped.add(entry)
            return frozenset(skipped)

        await to_thread(
            shutil.copytree,
            from_dir,
            to_dir,
            ignore=skip,
            symlinks=True,
            dirs_exist_ok=True,
        )

    def localfilepath(self, url: str) -> str:
        """Return a local file path corresponding to URL.

        The URL path is taken relative to ``_URL_PATH_PREFIX``. It must lie
        under this type's path (``_PATHS[self.type]``). The remainder below
        that path is joined onto ``files_path(self.type)``.

        :raise ValueError: if *url* does not lie under this type's path
                           (this also rejects paths that would climb out of
                           it via ``..``)
        """
        url_relpath = os.path.relpath(
            urlparse(url).path, self._URL_PATH_PREFIX
        )
        type_path = self._PATHS[self.type]
        # URLs come from the downloaded description.json, so this check
        # protects the local file system. It must not be an ``assert``,
        # which is removed when Python runs with -O.
        if pathlib.Path(type_path) not in pathlib.Path(url_relpath).parents:
            raise ValueError(
                "url ({}) does not fit file path ({})".format(url, type_path)
            )

        relative_path = os.path.relpath(url_relpath, type_path)
        return os.path.join(self.files_path(self.type), relative_path)

    def _touch(self) -> None:
        """Bump the mtime of description.json so the index looks fresh.

        Nothing happens if the file does not exist. An OSError is logged as
        a warning instead of being raised.
        """
        try:
            description = self._descriptionfile_path()
            if not os.path.isfile(description):  # pragma: no branch
                return
            os.utime(description)
        except OSError as err:  # pragma: no cover
            logger.warning(str(err))

    async def _run_hooks(self, is_updated) -> None:
        """Await every hook registered for this type, then ``default_hook``.

        Each hook is called as ``hook(self, is_updated)``. An exception from
        one hook is logged and does not prevent the remaining hooks from
        running.
        """
        hooks = chain(self._HOOKS[self.type], [default_hook])
        for callback in hooks:
            try:
                await callback(self, is_updated)
            except Exception as exc:
                logger.exception("hook %s error: %s", callback, exc)
        suffix = "" if is_updated else " (not updated)"
        logger.info("%s files update finished%s", self.type, suffix)

    async def update(self, force=False) -> None:
        """Bring the local files of the current `type` up to date.

        Unless *force* is set, ``is_update_needed()`` decides first whether
        anything has to be done. If not, the hooks run with
        ``is_updated=False`` and the method returns.

        When the index is blank (``_is_blank``) and the type supports it,
        the whole set is fetched as all.zip. If that fails with a timeout or
        UpdateError, or all.zip is not applicable, the files are downloaded
        one by one. Each attempt is limited by its own
        ``config.FilesUpdate.TIMEOUT``.

        After a completed update the hooks get ``is_updated=True`` when the
        files changed or *force* was given.

        Raises asyncio.TimeoutError, UpdateError. File by file download
        errors are re-raised only for essential types. For other types they
        are logged, the hooks run with ``is_updated=False``, and the method
        returns.
        """
        timeout = config.FilesUpdate.TIMEOUT  # limit for each attempt
        if not (force or await self.is_update_needed(timeout)):
            logger.info(
                "%s was updated less than %s minutes ago.",
                self.type,
                int(config.FilesUpdate.PERIOD // 60),
            )
            await self._run_hooks(is_updated=False)
            return

        all_zip = self._is_blank and self._ALL_ZIP_SUPPORT[self.type]
        need_file_by_file = not all_zip
        if all_zip:
            method = "all.zip"
            logger.info("Updating %s files via %s", self.type, method)
            # description.json is empty or corrupted: try the single
            # all.zip archive first, fall back to file by file on error.
            try:
                updated = await asyncio.wait_for(
                    self._run_update_all_zip(
                        config.FilesUpdate.SOCKET_TIMEOUT
                    ),
                    timeout,
                )
            except (asyncio.TimeoutError, UpdateError) as err:
                logger.error(
                    "%s update error via %s: %s", self.type, method, err
                )
                need_file_by_file = True
            else:
                if updated:
                    logger.info("Updated %s using %s", self.type, method)

        if need_file_by_file:
            method = "file by file download"
            logger.info("Updating %s files via %s", self.type, method)
            try:
                updated = await asyncio.wait_for(
                    self._run_update(config.FilesUpdate.SOCKET_TIMEOUT),
                    timeout,
                )
            except (asyncio.TimeoutError, UpdateError) as err:
                logger.error(
                    "%s update error via %s: %s", self.type, method, err
                )
                await self._run_hooks(is_updated=False)
                # only non-essential files are allowed to fail silently
                if self.type not in self._ESSENTIAL_TYPES:
                    return
                raise err
            if updated:
                logger.info("Updated %s using %s", self.type, method)

        await self._run_hooks(is_updated=updated or force)

    @classmethod
    async def update_all(
        cls, only_type: Optional[str] = None, force=False, only_essential=False
    ) -> None:
        """Update registered file types, one type at a time.

        If *only_type* is given, only that type is updated. Otherwise
        *only_essential* restricts the run to ``_ESSENTIAL_TYPES``. By
        default every type in ``_TYPES`` is updated. Each type's update runs
        while holding ``cls.locked(type_)``.

        Raises asyncio.TimeoutError, UpdateError.
        """
        if only_type:
            selected = [only_type]
        elif only_essential:
            logger.info("Updating essential files")
            selected = cls._ESSENTIAL_TYPES
        else:
            logger.info("Updating all files")
            selected = cls._TYPES

        for type_ in selected:
            index = cls(type_, integrity_check=False)
            async with cls.locked(type_):
                await index.update(force)

    @classmethod
    def add_hook(cls, type_: str, hook) -> None:
        """Register *hook* to run after each update attempt for *type_*.

        Hooks are awaited as ``hook(index, is_updated)``, so they also run
        when nothing was updated.
        """
        registered = cls._HOOKS[type_]
        registered.add(hook)


def configure() -> None:
    """Register the file types kept in sync with the remote server.

    Each entry gives the type name, its remote path, the directory and file
    permissions, and whether all.zip downloads are supported.
    """
    registrations = (
        (EULA, "eula/v1", 0o770, 0o660, False),
        (SIGS, "sigs/v1", 0o775, 0o644, True),
        (REALTIME_AV_CONF, "realtime-av-conf/v1", 0o770, 0o660, False),
    )
    for type_, path, dir_perms, file_perms, all_zip in registrations:
        Index.add_type(type_, path, dir_perms, file_perms, all_zip=all_zip)


# Module-level shortcuts to the Index class API.
update, essential_files_exist = (
    Index.update_all,
    Index.essential_files_exist,
)


async def update_and_log_error(
    only_type: Optional[str] = None, force=False
) -> None:
    """Run ``Index.update_all``, logging Update/Timeout errors.

    ``asyncio.TimeoutError`` and ``UpdateError`` are logged, not raised.
    """
    try:
        outcome = await Index.update_all(only_type, force)
    except (asyncio.TimeoutError, UpdateError) as exc:
        logger.error(
            "Failed to update files [%s] with error: %s", only_type, exc
        )
        return None
    return outcome


async def update_all_no_fail_if_files_exist():
    """Update the essential files; tolerate failure if they already exist.

    Only the essential file types are updated
    (``Index.update_all(only_essential=True)``). If that fails with
    ``asyncio.TimeoutError`` or ``UpdateError`` and the essential files are
    already present, the error is only logged. Otherwise it is re-raised,
    with a timeout wrapped into ``UpdateError``.

    :raise UpdateError: if the update failed and the essential files do not
                        exist
    """
    try:
        return await Index.update_all(only_essential=True)
    except (asyncio.TimeoutError, UpdateError) as err:
        if await Index.essential_files_exist():
            logger.error(
                "Failed to update files [essential files exist]: %s", err
            )
        elif isinstance(err, asyncio.TimeoutError):
            # wrap the timeout into UpdateError, with a readable message
            raise UpdateError(
                "Timed out while updating essential files"
            ) from err
        else:
            raise

Zerion Mini Shell 1.0