diff --git a/Dockerfile b/Dockerfile index a174283..1f753b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ WORKDIR /app COPY requirements.txt . RUN apt-get update && apt-get install -y --no-install-recommends gcc libffi-dev libssl-dev -RUN pip install asyncio redis +RUN pip install asyncio aiohttp redis[hiredis] RUN pip install --no-cache-dir -r requirements.txt # Stage 2: Final stage diff --git a/handlers/handle_private.py b/handlers/handle_private.py index e094c89..e66dd15 100644 --- a/handlers/handle_private.py +++ b/handlers/handle_private.py @@ -2,7 +2,7 @@ from bot.config import FEEDBACK_CHAT_ID from bot.announce import edit_announce from bot.api import telegram_api import logging -from utils.store import get_all_removed, get_average_toxic +from utils.store import get_all_pattern, get_average_pattern logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -29,12 +29,15 @@ async def handle_private(msg, state): state['welcome'] = False return elif text.startswith('/toxic'): - toxic_score = await get_average_toxic(msg) + toxic_pattern = f"toxic:{uid}:{cid}:*" + toxic_score = await get_average_toxic(toxic_pattern) text = f"Средняя токсичность сообщений: {toxic_score}%" - await telegram_api("sendMessage", chat_id=uid, reply_to_message_id=msg.get("message_id"), text=text) + mid = msg.get("message_id") + await telegram_api("sendMessage", chat_id=uid, reply_to_message_id=mid, text=text) return elif text == '/removed': - removed_messages = await get_all_removed(uid) + removed_pattern = f"removed:{uid}:*" + removed_messages = await get_all_pattern(removed_pattern) if removed_messages: await telegram_api("sendMessage", chat_id=uid, text="\n\n".join(removed_messages)) return diff --git a/handlers/messages_routing.py b/handlers/messages_routing.py index 934ce8e..26860ff 100644 --- a/handlers/messages_routing.py +++ b/handlers/messages_routing.py @@ -1,6 +1,6 @@ import logging import math -from utils.store import redis, get_average_toxic +from utils.store import redis, get_average_pattern from bot.api import telegram_api from bot.config import FEEDBACK_CHAT_ID from nlp.toxicity_detector import detector @@ -44,9 +44,10 @@ async def messages_routing(msg, state): reply_to_msg_id = int(latest_toxic_message_id) # count average between all of messages - toxic_score = await get_average_toxic(msg) + toxic_pattern = f"toxic:{cid}:{uid}:*" + toxic_score = await get_average_pattern(toxic_pattern) - # + # current mesasage toxicity if reply_to_msg_id: one_score = await redis.get(f"toxic:{cid}:{uid}:{reply_to_msg_id}") if one_score: diff --git a/utils/store.py b/utils/store.py index be8efc7..96f7943 100644 --- a/utils/store.py +++ b/utils/store.py @@ -7,11 +7,86 @@ import logging logger = logging.getLogger('store') logging.basicConfig(level=logging.DEBUG) -# Connect to Redis -redis = aredis.Redis.from_url(REDIS_URL) + +class RedisService: + def __init__(self, uri=REDIS_URL): + self._uri: str = uri + self.pubsub_channels = [] + self._client = None + + async def connect(self): + self._client = aredis.Redis.from_url(self._uri, decode_responses=True) + + async def disconnect(self): + if self._client: + await self._client.close() + + async def execute(self, command, *args, **kwargs): + if self._client: + try: + logger.debug(f"{command}") # {args[0]}") # {args} {kwargs}") + for arg in args: + if isinstance(arg, dict): + if arg.get("_sa_instance_state"): + del arg["_sa_instance_state"] + r = await self._client.execute_command(command, *args, **kwargs) + # logger.debug(type(r)) + # logger.debug(r) + return r + except Exception as e: + logger.error(e) + + async def subscribe(self, *channels): + if self._client: + async with self._client.pubsub() as pubsub: + for channel in channels: + await pubsub.subscribe(channel) + self.pubsub_channels.append(channel) + + async def unsubscribe(self, *channels): + if not self._client: + return + async with self._client.pubsub() as pubsub: + for channel in channels: + await pubsub.unsubscribe(channel) + self.pubsub_channels.remove(channel) + + async def publish(self, channel, data): + if not self._client: + return + await self._client.publish(channel, data) + + async def set(self, key, data, ex=None): + # Prepare the command arguments + args = [key, data] + + # If an expiration time is provided, add it to the arguments + if ex is not None: + args.append("EX") + args.append(ex) + + # Execute the command with the provided arguments + await self.execute("set", *args) + + async def scan_iter(self, pattern='*'): + """Asynchronously iterate over keys matching the given pattern.""" + cursor = '0' + while cursor != 0: + cursor, keys = await self._client.scan(cursor=cursor, match=pattern) + for key in keys: + yield key + + async def get(self, key): + return await self.execute("get", key) -async def get_all_removed(uid): +redis = RedisService() + +__all__ = ["redis"] + + + +async def get_all_pattern(uid): pattern = f"removed:{uid}:*" # Create a dictionary to hold the keys and values @@ -27,10 +102,7 @@ async def get_all_removed(uid): return texts -async def get_average_toxic(msg): - uid = msg['from']['id'] - cid = msg['chat']['id'] - pattern = f"toxic:{cid}:{uid}:*" +async def get_average_pattern(pattern): scores = [] scoring_msg_id = 0 async for key in redis.scan_iter(pattern):