Source code for py_abac.storage.redis.storage

"""
    Redis policy storage
"""

import json
import logging
from itertools import islice
from typing import Generator, Union

from redis import Redis

from ..base import Storage
from ...exceptions import PolicyExistsError
from ...policy import Policy

LOG = logging.getLogger(__name__)

DEFAULT_HASH_KEY = "py_abac_policies"


[docs]class RedisStorage(Storage): """ Redis policy storage backend. :param client: redis client. :param hash_key: hash key under which policies are stored in database. """ def __init__(self, client: Redis, hash_key: str = None): self.client = client self._hash = hash_key or DEFAULT_HASH_KEY
[docs] def add(self, policy: Policy): """ Store a policy """ rvalue = self.client.hsetnx(self._hash, policy.uid, self.__to_policy_str(policy)) if rvalue == 0: LOG.error('Error trying to create already existing policy with UID=%s.', policy.uid) raise PolicyExistsError(policy.uid) LOG.info('Added Policy: %s', policy)
[docs] def get(self, uid: str) -> Union[Policy, None]: """ Get specific policy """ policy_str = self.client.hget(self._hash, uid) if not policy_str: return None return self.__to_policy(policy_str)
[docs] def get_all(self, limit: int, offset: int) -> Generator[Policy, None, None]: """ Retrieve all the policies within a window .. note: Redis doesn't guarantee the exact number of elements returned thus this method gets all policies from Redis and manually slices the list. """ self._check_limit_and_offset(limit, offset) rvalue = self.client.hgetall(self._hash) policies = islice(rvalue.values(), offset, offset + limit) for policy_str in policies: yield self.__to_policy(policy_str)
[docs] def get_for_target( self, subject_id: str, resource_id: str, action_id: str ) -> Generator[Policy, None, None]: """ Get all policies for given target IDs. .. note: Currently all policies are returned for evaluation by PDP. """ # TODO: Create topologically sorted graph index for filtered retrieval. rvalue = self.client.hgetall(self._hash) for uid in rvalue: policy_str = rvalue[uid] yield self.__to_policy(policy_str)
[docs] def update(self, policy: Policy): """ Update a policy NOTE: The lua script is used to make sure an update operation occurs instead of upsert. """ uid = policy.uid lua = \ """ if redis.call('HEXISTS', KEYS[1], ARGV[1]) == 1 then return redis.call('HSET', KEYS[1], ARGV[1], ARGV[2]) end """ update_policy = self.client.register_script(lua) rvalue = update_policy(keys=[self._hash], args=[uid, self.__to_policy_str(policy)]) if rvalue is not None: LOG.info('Updated Policy with UID=%s. New value is: %s', uid, policy)
[docs] def delete(self, uid: str): """ Delete a policy """ rvalue = self.client.hdel(self._hash, uid) if rvalue != 0: LOG.info('Deleted Policy with UID=%s.', uid)
@staticmethod def __to_policy(policy_str: bytes) -> Policy: """ Converts stored policy string to policy object. """ policy_json = json.loads(policy_str.decode("utf-8")) return Policy.from_json(policy_json) @staticmethod def __to_policy_str(policy: Policy) -> str: """ Converts policy object to string for storage on Redis. """ policy_json = policy.to_json() return json.dumps(policy_json)