From a7b1925e8dabd2b2cbe0c3dbc0c8f17991a0632b Mon Sep 17 00:00:00 2001 From: Untone Date: Fri, 27 Sep 2024 09:23:55 +0300 Subject: [PATCH] ruffed --- bot/announce.py | 26 +++++++---- bot/api.py | 15 ++++--- handlers/handle_join_request.py | 34 ++++++++++++--- handlers/handle_private.py | 36 +++++++++------ handlers/messages_routing.py | 61 ++++++++++++++------------ main.py | 16 +++++-- {utils => nlp}/normalize.py | 75 +++++++++++++++++--------------- nlp/toxicity_detector.py | 17 +++++--- utils/store.py => state/redis.py | 38 ++-------------- state/scan.py | 34 +++++++++++++++ utils/graph.py | 3 +- utils/mention.py | 2 +- 12 files changed, 218 insertions(+), 139 deletions(-) rename {utils => nlp}/normalize.py (54%) rename utils/store.py => state/redis.py (72%) create mode 100644 state/scan.py diff --git a/bot/announce.py b/bot/announce.py index 3d36dbf..b58bb17 100644 --- a/bot/announce.py +++ b/bot/announce.py @@ -1,12 +1,13 @@ from bot.api import telegram_api -from utils.mention import mention, userdata_extract -from utils.store import redis +from utils.mention import userdata_extract +from state.redis import redis import logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) + def get_newcomer_message(msg): lang = msg["from"].get("language_code", "ru") r = "хочет присоединиться к нам здесь" if lang == "ru" else " wants to join us here" @@ -27,27 +28,36 @@ async def show_announce(msg): userphotos_response = await telegram_api("getUserphotos", user_id=from_id) file_id = "" - if isinstance(userphotos_response, dict) and userphotos_response["ok"] and userphotos_response["result"]["total_count"] > 0: + if ( + isinstance(userphotos_response, dict) + and userphotos_response["ok"] + and userphotos_response["result"]["total_count"] > 0 + ): logger.info("showing button with photo") file_id = userphotos_response["result"]["photos"][0][0]["file_id"] - r = await telegram_api("sendPhoto", + r = await telegram_api( + "sendPhoto", chat_id=chat_id, file_id=file_id, caption=newcomer_message, - reply_to=mid + reply_to=mid, ) announce_msg_id = r.get("message_id") - await redis.set(f"announce:{chat_id}:{from_id}", announce_message_id) + await redis.set(f"announce:{chat_id}:{from_id}", announce_msg_id) async def edit_announce(msg): logger.info("editing announce") chat_id = str(msg["chat"]["id"]) from_id = str(msg["from"]["id"]) - mid = msg.get("message_id", "") caption = get_newcomer_message(msg) + msg.get("text").replace("/message ", "") announce_message_id = await redis.get(f"announce:{chat_id}:{from_id}") if announce_message_id: - r = await telegram_api("editMessageCaption", chat_id=chat_id, message_id=int(announce_message_id), caption=caption) + r = await telegram_api( + "editMessageCaption", + chat_id=chat_id, + message_id=int(announce_message_id), + caption=caption, + ) await redis.set(f"announce:{chat_id}:{from_id}", r.get("message_id")) diff --git a/bot/api.py b/bot/api.py index 33e469c..e95bf1f 100644 --- a/bot/api.py +++ b/bot/api.py @@ -5,7 +5,7 @@ from bot.config import BOT_TOKEN import logging # Create a logger instance -logger = logging.getLogger('bot.api') +logger = logging.getLogger("bot.api") logging.basicConfig(level=logging.DEBUG) api_base = f"https://api.telegram.org/bot{BOT_TOKEN}/" @@ -14,17 +14,20 @@ api_base = f"https://api.telegram.org/bot{BOT_TOKEN}/" async def telegram_api(endpoint: str, json_data=None, **kwargs): try: url = api_base + f"{endpoint}?{urlencode(kwargs)}" - is_polling = endpoint == 'getUpdates' - headers = {'Content-Type': 'application/json'} + is_polling = endpoint == "getUpdates" + headers = {"Content-Type": "application/json"} async with aiohttp.ClientSession() as session: url = api_base + f"{endpoint}?{urlencode(kwargs)}" if not is_polling: logger.info(f' >>> {url} {json_data if json_data else ""}') - async with session.get(url, data=json.dumps(json_data), headers=headers) as response: + async with session.get( + url, data=json.dumps(json_data), headers=headers + ) as response: data = await response.json() if not is_polling: - logger.info(f' <<< {data}') + logger.info(f" <<< {data}") return data - except Exception as ex: + except Exception: import traceback + traceback.print_exc() diff --git a/handlers/handle_join_request.py b/handlers/handle_join_request.py index a29aa50..82fe5b8 100644 --- a/handlers/handle_join_request.py +++ b/handlers/handle_join_request.py @@ -9,10 +9,30 @@ logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) -positive_reactions = ["👍", "❤", "🔥", "🥰", "👏", "🎉", "🙏", "👌", "🕊", "😍", "❤‍🔥", "🍓", "🍾", "💋", "😇", "🤝", "🤗", "💘", "😘"] +positive_reactions = [ + "👍", + "❤", + "🔥", + "🥰", + "👏", + "🎉", + "🙏", + "👌", + "🕊", + "😍", + "❤‍🔥", + "🍓", + "🍾", + "💋", + "😇", + "🤝", + "🤗", + "💘", + "😘", +] announced_message = { "ru": "Запрос на вступление опубликован в чате, как только вас узнают и отреагируют - она будет принята", - "en": "The join request is posted in the chat, once you are recognized and someone reacted to - it will be accepted" + "en": "The join request is posted in the chat, once you are recognized and someone reacted to - it will be accepted", } @@ -22,12 +42,14 @@ async def handle_join_request(join_request): lang = user.get("language_code", "ru") # показываем для FEEDBACK_CHAT - await telegram_api("sendMessage", chat_id=FEEDBACK_CHAT_ID, text="новая заявка от " + mention(user)) + await telegram_api( + "sendMessage", chat_id=FEEDBACK_CHAT_ID, text="новая заявка от " + mention(user) + ) # показываем анонс с заявкой await show_announce(join_request) # сообщаем пользователю, что опубликовали анонс его заявки - await telegram_api("sendMessage", chat_id=user['id'], text=announced_message[lang]) + await telegram_api("sendMessage", chat_id=user["id"], text=announced_message[lang]) async def handle_reaction_on_request(update): @@ -39,5 +61,7 @@ async def handle_reaction_on_request(update): new_reaction = reaction.get("new_reaction") if new_reaction.get("emoji") in positive_reactions: # за пользователя поручились - r = await telegram_api("approveChatJoinRequest", chat_id=chat_id, user_id=from_id) + r = await telegram_api( + "approveChatJoinRequest", chat_id=chat_id, user_id=from_id + ) logger.debug(r) diff --git a/handlers/handle_private.py b/handlers/handle_private.py index e66dd15..3350f4e 100644 --- a/handlers/handle_private.py +++ b/handlers/handle_private.py @@ -2,14 +2,14 @@ from bot.config import FEEDBACK_CHAT_ID from bot.announce import edit_announce from bot.api import telegram_api import logging -from utils.store import get_all_pattern, get_average_pattern +from state.scan import get_all_pattern, get_average_pattern logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) start_message = { - 'en': "Welcome home! You can type any message here to be passed to chat", - 'ru': "Доброе утро! Можешь напечатать здесь любое сообщение для передачи в чат" + "en": "Welcome home! You can type any message here to be passed to chat", + "ru": "Доброе утро! Можешь напечатать здесь любое сообщение для передачи в чат", } @@ -21,25 +21,35 @@ async def handle_private(msg, state): if lang != "ru" and lang != "en": lang = "en" if text and text.startswith("/"): - if text == '/start': + if text == "/start": await telegram_api("sendMessage", chat_id=uid, text=start_message[lang]) - state['welcome'] = True - elif state.get('welcome'): + state["welcome"] = True + elif state.get("welcome"): await edit_announce(msg) - state['welcome'] = False + state["welcome"] = False return - elif text.startswith('/toxic'): + elif text.startswith("/toxic"): + cid = msg.get("chat", {}).get("id") toxic_pattern = f"toxic:{uid}:{cid}:*" - toxic_score = await get_average_toxic(toxic_pattern) + toxic_score = await get_average_pattern(toxic_pattern) text = f"Средняя токсичность сообщений: {toxic_score}%" mid = msg.get("message_id") - await telegram_api("sendMessage", chat_id=uid, reply_to_message_id=mid, text=text) + await telegram_api( + "sendMessage", chat_id=uid, reply_to_message_id=mid, text=text + ) return - elif text == '/removed': + elif text == "/removed": removed_pattern = f"removed:{uid}:*" removed_messages = await get_all_pattern(removed_pattern) if removed_messages: - await telegram_api("sendMessage", chat_id=uid, text="\n\n".join(removed_messages)) + await telegram_api( + "sendMessage", chat_id=uid, text="\n\n".join(removed_messages) + ) return - await telegram_api("forwardMessage", from_chat_id=sender.get("id"), message_id=msg.get("message_id"), chat_id=FEEDBACK_CHAT_ID) + await telegram_api( + "forwardMessage", + from_chat_id=sender.get("id"), + message_id=msg.get("message_id"), + chat_id=FEEDBACK_CHAT_ID, + ) diff --git a/handlers/messages_routing.py b/handlers/messages_routing.py index 26860ff..114733d 100644 --- a/handlers/messages_routing.py +++ b/handlers/messages_routing.py @@ -1,15 +1,17 @@ import logging import math -from utils.store import redis, get_average_pattern +from state.redis import redis +from state.scan import get_average_pattern from bot.api import telegram_api from bot.config import FEEDBACK_CHAT_ID -from nlp.toxicity_detector import detector from handlers.handle_private import handle_private -from utils.normalize import normalize +from nlp.toxicity_detector import detector +from nlp.normalize import normalize -logger = logging.getLogger('handlers.messages_routing') +logger = logging.getLogger("handlers.messages_routing") logging.basicConfig(level=logging.DEBUG) + async def messages_routing(msg, state): cid = msg["chat"]["id"] uid = msg["from"]["id"] @@ -28,11 +30,16 @@ async def messages_routing(msg, state): if reply_msg: reply_chat_id = reply_msg.get("chat", {}).get("id") if reply_chat_id != FEEDBACK_CHAT_ID: - await telegram_api("sendMessage", chat_id=reply_chat_id, text=text, reply_to_message_id=reply_msg.get("message_id")) + await telegram_api( + "sendMessage", + chat_id=reply_chat_id, + text=text, + reply_to_message_id=reply_msg.get("message_id"), + ) elif bool(text): mid = msg.get("message_id") - if text == '/toxic@welcomecenter_bot': + if text == "/toxic@welcomecenter_bot": # latest in chat latest_toxic_message_id = await redis.get(f"toxic:{cid}") @@ -52,46 +59,44 @@ async def messages_routing(msg, state): one_score = await redis.get(f"toxic:{cid}:{uid}:{reply_to_msg_id}") if one_score: logger.debug(one_score) - emoji = '😳' if toxic_score > 90 else '😟' if toxic_score > 80 else '😏' if toxic_score > 60 else '🙂' if toxic_score > 20 else '😇' + emoji = ( + "😳" + if toxic_score > 90 + else "😟" + if toxic_score > 80 + else "😏" + if toxic_score > 60 + else "🙂" + if toxic_score > 20 + else "😇" + ) text = f"{int(one_score)}% токсичности\nСредняя токсичность сообщений: {toxic_score}% {emoji}" await telegram_api( "sendMessage", chat_id=cid, reply_to_message_id=reply_to_msg_id, - text=text + text=text, ) - await telegram_api( - "deleteMessage", - chat_id=cid, - message_id=mid - ) - elif text == '/removed@welcomecenter_bot': - await telegram_api( - "deleteMessage", - chat_id=cid, - message_id=mid - ) + await telegram_api("deleteMessage", chat_id=cid, message_id=mid) + elif text == "/removed@welcomecenter_bot": + await telegram_api("deleteMessage", chat_id=cid, message_id=mid) else: toxic_score = detector(normalize(text)) - toxic_perc = math.floor(toxic_score*100) + toxic_perc = math.floor(toxic_score * 100) await redis.set(f"toxic:{cid}", mid) - await redis.set(f"toxic:{cid}:{uid}:{mid}", toxic_perc, ex=60*60*24*3) - logger.info(f'\ntext: {text}\ntoxic: {toxic_perc}%') + await redis.set(f"toxic:{cid}:{uid}:{mid}", toxic_perc, ex=60 * 60 * 24 * 3) + logger.info(f"\ntext: {text}\ntoxic: {toxic_perc}%") if toxic_score > 0.81: if toxic_score > 0.90: await redis.set(f"removed:{uid}:{cid}:{mid}", text) - await telegram_api( - "deleteMessage", - chat_id=cid, - message_id=mid - ) + await telegram_api("deleteMessage", chat_id=cid, message_id=mid) else: await telegram_api( "setMessageReaction", chat_id=cid, is_big=True, message_id=mid, - reaction=f'[{{"type":"emoji", "emoji":"🙉"}}]' + reaction='[{"type":"emoji", "emoji":"🙉"}]', ) else: diff --git a/main.py b/main.py index d731556..46e2fc2 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,11 @@ async def start(): logger.info("\n\npolling started\n\n") offset = 0 # init offset while True: - response = await telegram_api("getUpdates", offset=offset, allowed_updates=['message', 'message_reaction', 'chat_join_request']) + response = await telegram_api( + "getUpdates", + offset=offset, + allowed_updates=["message", "message_reaction", "chat_join_request"], + ) # logger.debug(response) if isinstance(response, dict): result = response.get("result", []) @@ -35,15 +39,21 @@ async def start(): except Exception as e: logger.error(e) import traceback + text = traceback.format_exc() formatted_text = f"```log\n{text}```" - await telegram_api("sendMessage", chat_id=FEEDBACK_CHAT_ID, text=formatted_text, parse_mode='MarkdownV2') + await telegram_api( + "sendMessage", + chat_id=FEEDBACK_CHAT_ID, + text=formatted_text, + parse_mode="MarkdownV2", + ) offset = update["update_id"] + 1 await asyncio.sleep(1.0) else: - logger.error(' \n\n\n!!! getUpdates polling error\n\n\n') + logger.error(" \n\n\n!!! getUpdates polling error\n\n\n") await asyncio.sleep(30.0) diff --git a/utils/normalize.py b/nlp/normalize.py similarity index 54% rename from utils/normalize.py rename to nlp/normalize.py index 37b2b9d..319cc58 100644 --- a/utils/normalize.py +++ b/nlp/normalize.py @@ -1,4 +1,3 @@ -import logging import torch from transformers import T5Tokenizer, T5ForConditionalGeneration @@ -6,75 +5,81 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration tokenizer = T5Tokenizer.from_pretrained("google/byt5-small") model = T5ForConditionalGeneration.from_pretrained("google/byt5-small") + def is_russian_wording(text): """ - Check if the text contains any Russian characters by checking + Check if the text contains any Russian characters by checking each character against the Unicode range for Cyrillic. """ for char in text: - if '\u0400' <= char <= '\u04FF': # Unicode range for Cyrillic characters + if "\u0400" <= char <= "\u04ff": # Unicode range for Cyrillic characters return True return False + def segment_text(text): """ Use a neural network model to segment text into words. """ # Encode the input text for the model inputs = tokenizer.encode("segment: " + text, return_tensors="pt") - + # Generate predictions with torch.no_grad(): outputs = model.generate(inputs) - + # Decode the generated tokens back to text segmented_text = tokenizer.decode(outputs[0], skip_special_tokens=True) - + return segmented_text + def normalize(text): """ Normalize English text to resemble Russian characters. """ # Segment the text first - segmented_text = segment_text(text.replace(' ', ' ').replace(' ', ' ').replace(' ', ' ')) - + segmented_text = segment_text( + text.replace(" ", " ").replace(" ", " ").replace(" ", " ") + ) + # Normalize after segmentation segmented_text = segmented_text.lower() - + if is_russian_wording(segmented_text): # Normalize the text by replacing characters - normalized_text = (segmented_text - .replace('e', 'е') - .replace('o', 'о') - .replace('x', 'х') - .replace('a', 'а') - .replace('r', 'г') - .replace('m', 'м') - .replace('u', 'и') - .replace('n', 'п') - .replace('p', 'р') - .replace('t', 'т') - .replace('y', 'у') - .replace('h', 'н') - .replace('i', 'й') - .replace('c', 'с') - .replace('k', 'к') - .replace('b', 'в') - .replace('3', 'з') - .replace('4', 'ч') - .replace('0', 'о') - .replace('d', 'д') - .replace('z', 'з')) - + normalized_text = ( + segmented_text.replace("e", "е") + .replace("o", "о") + .replace("x", "х") + .replace("a", "а") + .replace("r", "г") + .replace("m", "м") + .replace("u", "и") + .replace("n", "п") + .replace("p", "р") + .replace("t", "т") + .replace("y", "у") + .replace("h", "н") + .replace("i", "й") + .replace("c", "с") + .replace("k", "к") + .replace("b", "в") + .replace("3", "з") + .replace("4", "ч") + .replace("0", "о") + .replace("d", "д") + .replace("z", "з") + ) + return normalized_text - + return segmented_text + # Example usage if __name__ == "__main__": input_text = "Hello, this is a test input." - + normalized_output = normalize(input_text) print(normalized_output) - diff --git a/nlp/toxicity_detector.py b/nlp/toxicity_detector.py index 19c32c8..a56a21a 100644 --- a/nlp/toxicity_detector.py +++ b/nlp/toxicity_detector.py @@ -3,17 +3,22 @@ import torch import torch.nn.functional as F # Load tokenizer and model weights -tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier') -model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier') +tokenizer = BertTokenizer.from_pretrained( + "SkolkovoInstitute/russian_toxicity_classifier" +) +model = BertForSequenceClassification.from_pretrained( + "SkolkovoInstitute/russian_toxicity_classifier" +) + def detector(text): # Prepare the input - batch = tokenizer.encode(text, return_tensors='pt') + batch = tokenizer.encode(text, return_tensors="pt") # Inference with torch.no_grad(): result = model(batch) - + # Get logits logits = result.logits @@ -22,9 +27,11 @@ def detector(text): return probabilities[0][1].item() + if __name__ == "__main__": import sys + if len(sys.argv) > 1: p = detector(sys.argv[1]) toxicity_percentage = p * 100 # Assuming index 1 is for toxic class - print(f"Toxicity Probability: {toxicity_percentage:.2f}%") \ No newline at end of file + print(f"Toxicity Probability: {toxicity_percentage:.2f}%") diff --git a/utils/store.py b/state/redis.py similarity index 72% rename from utils/store.py rename to state/redis.py index 96f7943..c87f4fa 100644 --- a/utils/store.py +++ b/state/redis.py @@ -1,11 +1,10 @@ from bot.config import REDIS_URL -import asyncio import redis.asyncio as aredis import logging # Create a logger instance -logger = logging.getLogger('store') -logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger("state.redis") +logging.basicConfig(level=logging.WARNING) class RedisService: @@ -68,9 +67,9 @@ class RedisService: # Execute the command with the provided arguments await self.execute("set", *args) - async def scan_iter(self, pattern='*'): + async def scan_iter(self, pattern="*"): """Asynchronously iterate over keys matching the given pattern.""" - cursor = '0' + cursor = "0" while cursor != 0: cursor, keys = await self._client.scan(cursor=cursor, match=pattern) for key in keys: @@ -83,32 +82,3 @@ class RedisService: redis = RedisService() __all__ = ["redis"] - - - -async def get_all_pattern(uid): - pattern = f"removed:{uid}:*" - - # Create a dictionary to hold the keys and values - texts = [] - - # Use scan_iter to find all keys matching the pattern - async for key in redis.scan_iter(pattern): - # Fetch the value for each key - value = await redis.get(key) - if value: - texts.append(value.decode('utf-8')) - - return texts - - -async def get_average_pattern(pattern): - scores = [] - scoring_msg_id = 0 - async for key in redis.scan_iter(pattern): - scr = await redis.get(key) - if isinstance(scr, int): - scores.append(scr) - logger.debug(f'found {len(scores)} messages') - toxic_score = math.floor(sum(scores)/len(scores)) if scores else 0 - return toxic_score \ No newline at end of file diff --git a/state/scan.py b/state/scan.py new file mode 100644 index 0000000..df2304c --- /dev/null +++ b/state/scan.py @@ -0,0 +1,34 @@ +from state.redis import redis +import logging +import math + +# Create a logger instance +logger = logging.getLogger("state.scan") +logging.basicConfig(level=logging.DEBUG) + + +async def get_all_pattern(uid): + pattern = f"removed:{uid}:*" + + # Create a dictionary to hold the keys and values + texts = [] + + # Use scan_iter to find all keys matching the pattern + async for key in redis.scan_iter(pattern): + # Fetch the value for each key + value = await redis.get(key) + if value: + texts.append(value.decode("utf-8")) + + return texts + + +async def get_average_pattern(pattern): + scores = [] + async for key in redis.scan_iter(pattern): + scr = await redis.get(key) + if isinstance(scr, int): + scores.append(scr) + logger.debug(f"found {len(scores)} messages") + toxic_score = math.floor(sum(scores) / len(scores)) if scores else 0 + return toxic_score diff --git a/utils/graph.py b/utils/graph.py index 2b40254..b2c7fd3 100644 --- a/utils/graph.py +++ b/utils/graph.py @@ -1,8 +1,9 @@ - import logging + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) + # Define SVG code generation function with member_id parameter def generate_chart(members, member_id=None): if not member_id: diff --git a/utils/mention.py b/utils/mention.py index 09e7059..9bfbe48 100644 --- a/utils/mention.py +++ b/utils/mention.py @@ -12,7 +12,7 @@ def mention(user): def userdata_extract(user): - ln = " " + user.get('last_name', "") if user.get('last_name', "") else "" + ln = " " + user.get("last_name", "") if user.get("last_name", "") else "" identity = f"{user['first_name']}{ln}" uid = user["id"] username = user.get("username", "")