ok

Mini Shell

Directory : /opt/imunify360/venv/lib64/python3.11/site-packages/im360/api/
Upload File :
Current File : //opt/imunify360/venv/lib64/python3.11/site-packages/im360/api/ips.py

import asyncio
import time
from abc import ABCMeta, abstractclassmethod, abstractmethod
from functools import partial, wraps
from ipaddress import IPv4Network, IPv6Network
from typing import List, Union

from defence360agent.internals.global_scope import g
from defence360agent.model import instance
from defence360agent.model.simplification import run_in_executor
from im360.contracts.config import Protector
from defence360agent.contracts.messages import MessageType
from im360.internals import geo
from im360.internals.core.ipset.country import IPSetCountry
from im360.internals.core.ipset.ip import IPSet
from im360.internals.core.ipset.port import IPSetIgnoredByPort, IPSetPort
from im360.model.country import CountryList, Country as CountryModel
from im360.model.firewall import BlockedPort, IgnoredByPort, IPList
from im360.utils.net import pack_ip_network


def postprocess_records(func):
    """Post-process the ``(affected, not_affected)`` result of a coroutine.

    The decorated coroutine must return a 2-tuple ``(affected,
    not_affected)``. That pair is passed to
    ``self._postprocess_records(affected, not_affected)``, and the value it
    returns becomes the final result. When this decorator is stacked under
    ``@classmethod``, ``self`` is the class.

    ``functools.wraps`` keeps the wrapped method's name and docstring.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        affected, not_affected = await func(self, *args, **kwargs)
        return await self._postprocess_records(affected, not_affected)

    return wrapper


def with_rule_edit_lock(coro):
    """Serialize calls of coroutine function *coro* on the rule edit lock.

    Each call of the returned coroutine function acquires
    ``Protector.RULE_EDIT_LOCK`` (``async with``), awaits *coro* with the
    same arguments, and returns its result. The lock is released on the way
    out, including when *coro* raises.
    """

    @wraps(coro)
    async def locked(*args, **kwargs):
        lock = Protector.RULE_EDIT_LOCK
        async with lock:
            result = await coro(*args, **kwargs)
        return result

    return locked


class API(metaclass=ABCMeta):
    """Base class for list APIs that keep a DB table and an ipset in sync.

    Subclasses provide the DB part (``_create_record``, ``_delete_record``,
    ``_edit``) and an ``ipset`` object. This class provides the public
    ``block`` / ``unblock`` / ``edit`` entry points. Each one applies the
    per-item operation to every item and returns ``(affected,
    not_affected)`` (see ``split_result``), post-processed by
    ``_postprocess_records``.
    """

    # ipset wrapper mirrored from the DB table; subclasses must set it
    ipset = None

    @classmethod
    @abstractmethod
    def _create_record(cls, *args, **kwargs):
        """Create the DB record for one item; return ``(obj, created)``."""

    @classmethod
    @abstractmethod
    def _delete_record(cls, *args, **kwargs):
        """Delete the DB record(s) for one item; return the deleted count."""

    @classmethod
    @abstractmethod
    async def _edit(cls, *args, **kwargs):
        """Edit one item; return a truthy value if something changed."""

    @classmethod
    @with_rule_edit_lock
    async def _add(cls, *args, **kwargs):
        """Create the DB record for one item and block it if new.

        The DB work runs in an executor. The ipset is only touched when a
        new record was created.

        :return: the ``created`` flag returned by ``_create_record``.
        """
        _, created = await run_in_executor(
            asyncio.get_event_loop(),
            lambda: cls._create_record(*args, **kwargs),
        )

        if created:
            assert cls.ipset, "IPSet instance is missing for this API"
            await cls.ipset.block(*args, **kwargs)
        return created

    @classmethod
    @with_rule_edit_lock
    async def _delete(cls, *args, **kwargs):
        """Delete the DB record(s) for one item and unblock it if any were removed.

        :return: ``True`` if at least one DB record was deleted.
        """
        num_deleted = await run_in_executor(
            asyncio.get_event_loop(),
            lambda: cls._delete_record(*args, **kwargs),
        )

        if num_deleted:
            assert cls.ipset, "IPSet instance is missing for this API"
            await cls.ipset.unblock(*args, **kwargs)
        return num_deleted != 0

    @classmethod
    async def _postprocess_records(cls, affected, not_affected):
        """Add some fields in to result list, e.g 'listname'.

        The default implementation returns both lists unchanged.
        """
        return affected, not_affected

    @classmethod
    @postprocess_records
    async def block(cls, items, *args, **kwargs):
        """Add every item of *items*; return ``(affected, not_affected)``."""
        return await split_result(cls._add, items, *args, **kwargs)

    @classmethod
    @postprocess_records
    async def unblock(cls, items, *args, **kwargs):
        """Delete every item of *items*; return ``(affected, not_affected)``."""
        return await split_result(cls._delete, items, *args, **kwargs)

    @classmethod
    @postprocess_records
    async def edit(cls, items, *args, **kwargs):
        """Edit every item of *items*; return ``(affected, not_affected)``."""
        return await split_result(cls._edit, items, *args, **kwargs)


class IPApi(API):
    """IP list API: ``IPList`` records mirrored into the ``IPSet`` ipset.

    DB work runs in an executor. The matching ipset changes are applied
    afterwards, and mutating operations hold ``Protector.RULE_EDIT_LOCK``
    via ``with_rule_edit_lock``.
    """

    ipset = IPSet()

    @staticmethod
    def _create_record(
        ip,
        listname,
        imported_from=None,
        comment=None,
        full_access=False,
        expiration=0,
        **kwargs
    ):
        """DB part of adding *ip* to *listname* (BLACK, WHITE or GRAY).

        Everything runs in one DB transaction, in this order:

        1. Delete expired records of *ip*.
        2. If ``manual`` is truthy and *ip* is already in WHITE or BLACK,
           return the existing record untouched.
        3. Delete GRAY_SPLASHSCREEN/GRAY/BLACK records of the networks
           returned by ``IPList.find_net_members`` (with
           ``expired_by=expiration``). The exact (ip, listname, expiration)
           record being added is kept.
        4. Create the record with ``IPList.create_or_get``.

        :return: ``(obj, created, unblock)``. *unblock* is a
            ``BlockUnblockList`` whose ``unblocklist`` holds the
            ``(subnet, listname)`` pairs deleted from the DB, or ``None``
            when an existing manual record was returned.
        """
        assert listname in [IPList.BLACK, IPList.WHITE, IPList.GRAY]
        with instance.db.transaction(), geo.reader() as geo_reader:
            IPList.delete_expired(ip=ip)

            if kwargs.get("manual"):
                # for manual lists we want that IP should be in one list
                # in BLACK or in WHITE, so here is special case for that
                for list_ in [IPList.WHITE, IPList.BLACK]:
                    try:
                        obj = IPList.get(ip=ip, listname=list_)
                    except IPList.DoesNotExist:
                        pass
                    else:
                        # return obj and False (not created), so it will be
                        # processed later
                        return obj, False, None

            # remove expired subnets
            keep_manual = kwargs.pop("keep_manual_expired_subnets", None)
            unblocklist = []
            for subnet, list_, subnet_expiration in IPList.find_net_members(
                ip,
                listname=[
                    IPList.GRAY_SPLASHSCREEN,
                    IPList.GRAY,
                    IPList.BLACK,
                ],
                expired_by=expiration,
                include_itself=True,
                manual=False if keep_manual else None,
            ):
                if not (  # don't delete exact matches for listname
                    subnet == ip
                    and listname == list_
                    and subnet_expiration == expiration
                ):
                    unblocklist.append((subnet, list_))
                    IPApi._delete_record(subnet, list_)

            # add
            country = geo_reader.get_id(ip)

            return (
                *IPList.create_or_get(
                    ip=ip,
                    listname=listname,
                    imported_from=imported_from,
                    comment=comment,
                    country=country,
                    full_access=full_access,
                    expiration=expiration,
                    **kwargs
                ),
                MessageType.BlockUnblockList(
                    blocklist={}, unblocklist=unblocklist
                ),
            )

    @staticmethod
    def _delete_record(ip, listname):
        """Delete *ip* from one list name or from a list of list names.

        :return: the result of ``IPList.delete_from_list`` (used as the
            number of deleted records by callers).
        """
        if isinstance(listname, str):  # got a single list name
            listname = [listname]  # convert to a list
        return IPList.delete_from_list(ip=ip, listname=listname)

    @classmethod
    @with_rule_edit_lock
    async def _delete(cls, ip, listname):
        """Delete *ip* from *listname* in the DB and, if removed, from the ipset.

        :return: ``True`` if at least one DB record was deleted.
        """
        num_deleted = await run_in_executor(
            asyncio.get_event_loop(), lambda: cls._delete_record(ip, listname)
        )

        if num_deleted:
            assert cls.ipset, "IPSet instance is missing for this API"
            await cls.ipset.unblock(ip, listname)
        return num_deleted != 0

    @classmethod
    @with_rule_edit_lock
    async def _add(cls, ip, listname, *args, **kwargs):
        """Add *ip* to *listname* in the DB and, if newly created, to the ipset.

        A non-zero record expiration is turned into an ipset ``timeout``. If
        the new record has already expired, it is not put into the ipset.
        Subnet records that ``_create_record`` deleted from the DB are
        always removed from the ipset, including on the expired path, so the
        ipset does not keep entries the DB no longer has.

        :raises TypeError: if the created record's expiration is not an int.
        :return: the ``created`` flag, or ``None`` when the created record
            is already expired.
        """
        obj, created, unblock_ips = await run_in_executor(
            asyncio.get_event_loop(),
            lambda: IPApi._create_record(ip, listname, *args, **kwargs),
        )
        assert not (unblock_ips and unblock_ips.blocklist)

        if created:
            # "control" add to ipset list
            expiration = getattr(obj, "expiration", 0)
            if not isinstance(expiration, int):
                raise TypeError(
                    "expiration must be integer, got {}".format(
                        type(expiration)
                    )
                )
            if expiration:
                timeout = int(obj.expiration - time.time())
                if timeout <= 0:
                    # nothing to block, but the subnets removed from the DB
                    # must still leave the ipset
                    await cls._unblock_ips(unblock_ips)
                    return None
                kwargs["timeout"] = timeout
            await cls.ipset.block(ip, listname, *args, **kwargs)
        await cls._unblock_ips(unblock_ips)
        return created

    @classmethod
    async def _unblock_ips(cls, ips: "BlockUnblockList"):
        """Unblock *ips* from ipset/webshield."""
        if ips and ips.unblocklist:
            for ip, listname in ips.unblocklist:
                await cls.ipset.unblock(ip, listname)

    @classmethod
    @with_rule_edit_lock
    async def _move(cls, row, listname, full_access=False):
        """
        https://gerrit.cloudlinux.com/#/c/61260/22/src/handbook/message_processing/client_move.py

        * shouldn't move to GRAY* lists
        * do not move if already in list

        * remove lists which exactly same and leave only one record with
          IPList.NEVER expiration

        :param row: mapping with ``"ip"`` and ``"listnames"`` (source lists)
        :return: ``0`` if *listname* is already a source list, otherwise
            ``True`` if ``IPList.move`` updated anything.
        """

        ip, src = row["ip"], row["listnames"]
        if listname in src:
            # unable to move between the same lists
            return 0

        num_updated = await run_in_executor(
            asyncio.get_event_loop(),
            partial(
                IPList.move,
                ip=ip,
                dest=listname,
                src=src,
                full_access=full_access,
            ),
        )

        if num_updated:
            for src_listname in src:
                await cls.ipset.unblock(ip, src_listname)
            await cls.ipset.block(ip, listname, full_access=full_access)

        return num_updated != 0

    @classmethod
    @with_rule_edit_lock
    async def _edit(
        cls,
        ip,
        listname,
        comment=None,
        full_access=None,
        expiration=None,
        scope=None,
        allow_move=False,
        comment_autogenerated=False,
    ):
        """Implement manual "[ip]list ip edit" command.

        Only arguments that are not ``None`` are written. ``comment`` is
        also skipped when *comment_autogenerated* is true. Every edit marks
        the record ``manual`` and resets ``captcha_passed``. When
        *full_access* or *expiration* is given and a record was updated, the
        IP is re-added to the ipset with the new parameters.

        :raises TypeError: if *expiration* is given and is not an ``int``.
            This is checked before the DB is touched.
        :return: number of updated records. It is 0 when there is nothing to
            change, the record does not exist, or the new expiration is
            already in the past.
        """
        # validate up front so an invalid value is never persisted
        if expiration is not None and not isinstance(expiration, int):
            raise TypeError(
                "expiration must be integer, got {}".format(type(expiration))
            )

        fields = dict()
        if comment is not None and not comment_autogenerated:
            fields["comment"] = comment
        if full_access is not None:
            fields["full_access"] = full_access
        if expiration is not None:
            fields["expiration"] = expiration
        if scope is not None:
            fields["scope"] = scope
        if not fields:
            return 0

        fields["manual"] = True
        fields["captcha_passed"] = False

        num_updated, unblock_ips, changed_record = await run_in_executor(
            None,
            partial(
                cls._edit_record,
                ip=ip,
                listname=listname,
                fields=fields,
                allow_move=allow_move,
            ),
        )
        # TODO: consider unifying block/unblock calls via [un]blocklist
        assert not (unblock_ips and unblock_ips.blocklist)
        if num_updated and (full_access is not None or expiration is not None):
            kwargs = dict(full_access=full_access)
            if expiration is not None:
                kwargs["expiration"] = expiration
                if expiration:
                    timeout = int(expiration - time.time())
                    if timeout <= 0:
                        # IP already expired, marking it as not updated; the
                        # records deleted from the DB still leave the ipset
                        await cls._unblock_ips(unblock_ips)
                        return 0
                else:
                    timeout = 0
                kwargs["timeout"] = timeout
            # need to add IP into new IPSet
            await cls.ipset.unblock(ip, changed_record.listname)
            await cls.ipset.block(ip, listname, **kwargs)
        await cls._unblock_ips(unblock_ips)
        return num_updated

    @staticmethod
    def _edit_record(
        ip, listname, fields, allow_move=False
    ) -> "Tuple[int, Optional[BlockUnblockList], Optional[IPList]]":  # noqa
        """'{black,white}list ip edit' rpc command db part implementation.

        With *allow_move*, the record of *ip* is looked up in
        BLACK/WHITE/GRAY regardless of list name. If there are two such
        records, those whose list name differs from *listname* are deleted
        and the remaining one is edited (and moved to *listname* if needed).

        :return: ``(num_updated, unblock, record)``.
            * *unblock* is a ``BlockUnblockList`` of ``(ip/subnet, listname)``
              pairs deleted from the DB, or ``None`` if nothing was deleted
              and no record was found.
            * *record* is the edited record as it was before the update, or
              ``None`` if no record was found.
        """
        assert listname in [IPList.BLACK, IPList.WHITE, IPList.GRAY]
        with instance.db.transaction():
            unblocklist = []
            try:
                if allow_move:
                    net, mask, version = pack_ip_network(ip)
                    records = list(
                        IPList.select()
                        .where(
                            IPList.network_address == net,
                            IPList.netmask == mask,
                            IPList.version == version,
                            IPList.listname.in_(
                                [IPList.BLACK, IPList.WHITE, IPList.GRAY]
                            ),
                        )
                        .execute()
                    )
                    if len(records) == 2:
                        # remove record with another listname and
                        # move with listname equal to dest listname
                        for rec in records:
                            if rec.listname != listname:
                                IPList.delete_from_list(
                                    ip=ip, listname=[rec.listname]
                                )
                                unblocklist.append((ip, rec.listname))
                        # keep only the surviving record(s) so the one in
                        # the destination list is edited below
                        records = [
                            rec for rec in records if rec.listname == listname
                        ]
                    if len(records) == 1:
                        record = records[0]
                    else:
                        # TODO TBD
                        # we here because add method return some supernet
                        raise IPList.DoesNotExist()
                else:
                    record = IPList.get(ip=ip, listname=listname)
            except IPList.DoesNotExist:
                # can't edit a non-existing record, but records deleted
                # above must still be removed from the ipset by the caller
                return (
                    0,
                    MessageType.BlockUnblockList(
                        blocklist={}, unblocklist=unblocklist
                    )
                    if unblocklist
                    else None,
                    None,
                )
            else:
                # remove expiring less important subnets
                new_expiration = fields.get("expiration", 1)
                if record.lives_less(new_expiration):
                    for subnet, list_, _ in IPList.find_net_members(
                        ip,
                        listname=IPList.lists_with_less_or_equal_priorities(
                            record.listname
                        ),
                        expired_by=new_expiration,
                        include_itself=True,
                    ):
                        if not (  # don't delete exact matches for listname
                            subnet == ip and record.listname == list_
                        ):
                            unblocklist.append((subnet, list_))
                            IPApi._delete_record(subnet, list_)

                # update fields
                if allow_move and listname != record.listname:
                    fields["listname"] = listname
                # note: use the update query for atomicity
                num_updated = (
                    IPList.update(**fields)
                    .where(
                        (IPList.network_address == record.network_address)
                        & (IPList.netmask == record.netmask)
                        & (IPList.version == record.version)
                        & (IPList.listname == record.listname)
                    )
                    .execute()
                )
                return (
                    num_updated,
                    MessageType.BlockUnblockList(
                        blocklist={}, unblocklist=unblocklist
                    ),
                    record,
                )

    @classmethod
    async def _postprocess_records(cls, affected, not_affected):
        """
        Adds listname to every IP
        :param list of dicts affected:
        :param list of dicts not_affected:
        :return list of dicts, list of dicts
        """
        not_affected_processed = []
        for item in not_affected:
            listname = await run_in_executor(
                asyncio.get_event_loop(),
                IPList.effective_list,
                (
                    item["rec"]
                    if isinstance(item["rec"], (IPv4Network, IPv6Network))
                    else item["rec"]["ip"]
                ),
            )
            # listname here could be None if record already deleted or expired
            item.update(listname=listname)
            not_affected_processed.append(item)
        return affected, not_affected_processed

    @classmethod
    @postprocess_records
    async def move(cls, items, *args, **kwargs):
        """Move every item of *items*; return ``(affected, not_affected)``."""
        return await split_result(cls._move, items, *args, **kwargs)


class IPApiWithIdempotentAdd(IPApi):
    """IP list API whose ``_add`` can be repeated safely.

    When the parent ``_add`` does not create a new record, the same request
    is replayed as an ``_edit`` with ``allow_move=True``. A single call
    therefore covers what the parent class splits between ``_add``,
    ``_edit`` and ``_move``.
    """

    @classmethod
    async def _add(cls, ip, listname, *args, **kwargs):
        autogenerated = kwargs.pop("comment_autogenerated", False)
        was_created = await super()._add(ip, listname, *args, **kwargs)
        if was_created:
            return was_created

        # the parent `_edit` does not accept these create-only options
        for create_only in ("manual", "keep_manual_expired_subnets"):
            kwargs.pop(create_only, None)
        # default value for `full_access` in case move is false,
        # see src/asyncclient/im360/simple_rpc/schema/lists/ip/white.yaml
        if listname == IPList.WHITE and kwargs.get("full_access") is None:
            kwargs["full_access"] = False
        return await super()._edit(
            ip,
            listname,
            *args,
            allow_move=True,
            comment_autogenerated=autogenerated,
            **kwargs
        )


class MockedCountryIpset(IPSetCountry):
    """Country ipset whose ``unblock`` is a no-op.

    ``CountryAPI`` uses it as its ``ipset``, so deleting a country record
    through that API does not unblock anything here.
    """

    async def unblock(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None


class CountryAPI(API):
    """Country list API backed by ``CountryList``.

    Items are country codes. They are resolved to ``CountryModel`` rows
    with ``CountryModel.get(code=...)``.
    """

    ipset = MockedCountryIpset()

    @staticmethod
    def _delete_record(country, listname):
        """Remove the country with code *country* from *listname*."""
        return CountryList.delete_country(
            CountryModel.get(code=country), listname
        )

    @staticmethod
    def _create_record(country, listname, comment=None, **kwargs):
        """Create (or fetch) the ``CountryList`` row for *country* in *listname*."""
        country_id = CountryModel.get(code=country).id
        return CountryList.create_or_get(
            country=country_id, listname=listname, comment=comment
        )

    @classmethod
    async def _edit(cls, country, comment):
        """Set *comment* on the ``CountryList`` rows of *country*."""
        loop = asyncio.get_event_loop()

        def update_comment():
            query = CountryList.update(comment=comment)
            country_id = CountryModel.get(code=country).id
            return query.where(CountryList.country == country_id).execute()

        return await run_in_executor(loop, update_comment)

    @classmethod
    async def _postprocess_records(cls, affected, not_affected):
        """
        Attach the current list name to every not-affected country.
        :param list of dicts affected:
        :param list of dicts not_affected:
        :return list of dicts, list of dicts
        """
        processed = []
        for entry in not_affected:
            current_list = await run_in_executor(
                asyncio.get_event_loop(),
                partial(CountryList.get_listname, entry["rec"]),
            )
            entry.update(listname=current_list)
            processed.append(entry)
        return affected, processed


class PortAPI(API):
    """Blocked port API backed by ``BlockedPort``.

    Each item is a ``(port, proto)`` pair.
    """

    ipset = IPSetPort()

    @staticmethod
    def _delete_record(item):
        """Delete the ``BlockedPort`` rows matching the ``(port, proto)`` *item*."""
        port, proto = item
        query = BlockedPort.delete()
        matches = (BlockedPort.port == port) & (BlockedPort.proto == proto)
        return query.where(matches).execute()

    @staticmethod
    def _create_record(item, comment=None):
        """Create (or fetch) the ``BlockedPort`` row for the ``(port, proto)`` *item*."""
        port, proto = item
        return BlockedPort.create_or_get(port=port, proto=proto, comment=comment)

    @classmethod
    async def _edit(cls, item, comment=None):
        """Set *comment* on the ``BlockedPort`` rows matching *item*."""
        port, proto = item
        loop = asyncio.get_event_loop()

        def update_comment():
            query = BlockedPort.update(comment=comment)
            matches = (BlockedPort.port == port) & (BlockedPort.proto == proto)
            return query.where(matches).execute()

        return await run_in_executor(loop, update_comment)


class IgnoredByPortAPI(API):
    """Per-port IP exceptions backed by ``IgnoredByPort``.

    Each record links an IP (stored via ``pack_ip_network``) to an existing
    ``BlockedPort`` row identified by *port* and *proto*.
    """

    ipset = IPSetIgnoredByPort()

    @staticmethod
    def _delete_record(ip, port, proto):
        """Delete the ``IgnoredByPort`` rows of *ip* for the given port/proto.

        :return: the result of the delete query's ``execute()``.
        """
        port_proto = BlockedPort.get(port=port, proto=proto)

        net, mask, version = pack_ip_network(ip)
        return (
            IgnoredByPort.delete()
            .where(
                (IgnoredByPort.port_proto == port_proto)
                & (IgnoredByPort.network_address == net)
                & (IgnoredByPort.netmask == mask)
                & (IgnoredByPort.version == version)
            )
            .execute()
        )

    @staticmethod
    def _create_record(ip, port, proto, comment=None):
        """Create (or fetch) the ``IgnoredByPort`` row of *ip* for port/proto.

        The country of *ip* is resolved with the geo reader.
        """
        port_proto = BlockedPort.get(port=port, proto=proto)
        with geo.reader() as geo_reader:
            country = geo_reader.get_id(ip)
            return IgnoredByPort.create_or_get(
                port_proto=port_proto, ip=ip, comment=comment, country=country
            )

    @classmethod
    async def _edit(cls, ip, port, proto, comment=None):
        """Set *comment* on the ``IgnoredByPort`` rows of *ip* for port/proto."""
        net, mask, version = pack_ip_network(ip)

        def update_comment():
            # The BlockedPort lookup is a model query like the update below,
            # so it runs in the executor too instead of on the event loop.
            port_proto = BlockedPort.get(port=port, proto=proto)
            return (
                IgnoredByPort.update(comment=comment)
                .where(
                    (IgnoredByPort.network_address == net)
                    & (IgnoredByPort.netmask == mask)
                    & (IgnoredByPort.version == version)
                    & (IgnoredByPort.port_proto == port_proto)
                )
                .execute()
            )

        return await run_in_executor(asyncio.get_event_loop(), update_comment)


async def split_result(f, records, *args, **kwargs):
    """
    Apply coroutine function *f* to each record and partition the records.

    Each record is awaited as ``f(record, *args, **kwargs)``, one at a time
    and in order.

    :param f: coroutine function; a truthy result marks the record affected
    :param list|tuple records: records to process
    :return list affected: records for which *f* returned a truthy value,
            list of dicts not_affected: ``{"rec": record}`` for the others
    """
    assert isinstance(
        records, (list, tuple)
    ), 'items should be list or tuple, instead - "{}"'.format(records)

    affected = []
    not_affected = []
    for record in records:
        if await f(record, *args, **kwargs):
            affected.append(record)
            continue
        not_affected.append({"rec": record})
    return affected, not_affected


class GroupIPSyncSender:
    """Collect IP list records and send them as a ``GroupIPSync`` message.

    ``collect`` and ``filter`` return ``self`` so they can be chained.
    ``send`` emits the records that are held at that moment.
    """

    def __init__(self):
        # records fetched by ``collect`` (possibly narrowed by ``filter``);
        # None until ``collect`` has been called
        self._to_be_sent_to_correlation = None

    async def send(self, action):
        """Send the collected records to ``g.sink`` as ``GroupIPSync``.

        Does nothing when no records are held.

        :param action: ``"add"`` sends ip, expiration, list, full_access and
            comment of each record. ``"del"`` sends only ip and list.
        :raises ValueError: for any other *action*, when there is something
            to send.
        """
        records = self._to_be_sent_to_correlation
        if not records:
            return

        if action == "add":
            data = [
                dict(
                    ip=ip_model.ip,
                    expiration=ip_model.expiration,
                    list=ip_model.listname,
                    full_access=ip_model.full_access,
                    comment=ip_model.comment,
                )
                for ip_model in records
            ]
        elif action == "del":
            data = [
                dict(
                    ip=ip_model.ip,
                    list=ip_model.listname,
                )
                for ip_model in records
            ]
        else:
            # reject explicitly instead of failing on an unbound `data`
            raise ValueError(
                "unknown group sync action: {!r}".format(action)
            )

        await g.sink.process_message(
            MessageType.GroupIPSync(
                {
                    action: data,
                }
            )
        )

    async def collect(self, items: List[Union[IPv4Network, IPv6Network]]):
        """Load the records for *items* via ``IPList.fetch_for_group_sync``.

        The query runs in an executor. Returns ``self``.
        """
        self._to_be_sent_to_correlation = await run_in_executor(
            asyncio.get_event_loop(),
            lambda: IPList.fetch_for_group_sync(items),
        )
        return self

    def filter(self, items: List[Union[IPv4Network, IPv6Network]]):
        """Keep only records whose ``ip_network`` is one of *items*.

        A no-op if nothing has been collected yet. Returns ``self``.
        """
        if self._to_be_sent_to_correlation is None:
            return self
        # build the lookup set once: O(1) membership, and *items* may be a
        # one-shot iterator
        wanted = set(items)
        self._to_be_sent_to_correlation = [
            item
            for item in self._to_be_sent_to_correlation
            if item.ip_network in wanted
        ]
        return self

Zerion Mini Shell 1.0