diff --git a/handlers/messages_routing.py b/handlers/messages_routing.py
index 256a792..118b626 100644
--- a/handlers/messages_routing.py
+++ b/handlers/messages_routing.py
@@ -5,6 +5,7 @@ from state.scan import get_average_pattern
 from bot.api import telegram_api, download_file
 from bot.config import FEEDBACK_CHAT_ID
 from handlers.handle_private import handle_private
+from nlp.segment_text import segment_text
 from nlp.toxicity_detector import detector
 from nlp.normalize import normalize
 from nlp.ocr import ocr_recognize
@@ -16,22 +17,9 @@ async def messages_routing(msg, state):
     cid = msg["chat"]["id"]
     uid = msg["from"]["id"]
     text = msg.get("caption", msg.get("text", ""))
-
-    for photo in msg.get("photo", []):
-        file_id = photo.get("file_id")
-        if file_id:
-            async for temp_file_path in download_file(file_id):
-                text += ocr_recognize(temp_file_path)
-                text += '\n'
-
     reply_msg = msg.get("reply_to_message")
 
-    if cid == uid:
-        # сообщения в личке с ботом
-        logger.info("private chat message: ", msg)
-        await handle_private(msg, state)
-
-    elif str(cid) == FEEDBACK_CHAT_ID:
+    if str(cid) == FEEDBACK_CHAT_ID:
         # сообщения из группы обратной связи
         logger.info("feedback chat message: ", msg)
         logger.debug(msg)
@@ -44,8 +32,15 @@ async def messages_routing(msg, state):
                 text=text,
                 reply_to_message_id=reply_msg.get("message_id"),
             )
+        return
 
-    elif bool(text):
+    elif cid == uid:
+        # сообщения в личке с ботом
+        logger.info("private chat message: ", msg)
+        await handle_private(msg, state)
+        return
+
+    elif bool(text) or msg.get("photo"):
         mid = msg.get("message_id")
         if text == "/toxic@welcomecenter_bot":
             # latest in chat
@@ -106,8 +101,17 @@ async def messages_routing(msg, state):
             except Exception:
                 pass
         else:
+            # on screen recognition
+            for photo in msg.get("photo", []):
+                file_id = photo.get("file_id")
+                if file_id:
+                    async for temp_file_path in download_file(file_id):
+                        text += ocr_recognize(temp_file_path)
+                        text += '\n'
+
             normalized_text = normalize(text)
-            toxic_score = detector(normalized_text)
+            segmented_text = segment_text(normalized_text)
+            toxic_score = detector(segmented_text)
             toxic_perc = math.floor(toxic_score * 100)
 
             if toxic_perc > 49:
diff --git a/nlp/segment_text.py b/nlp/segment_text.py
index 67cbea8..56af820 100644
--- a/nlp/segment_text.py
+++ b/nlp/segment_text.py
@@ -1,25 +1,17 @@
-import torch
-from transformers import ByT5Tokenizer, T5ForConditionalGeneration
-
-# Use ByT5 for the ByT5 model
-tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
-model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
-
-
+import spacy
+# Load the Russian language model
+nlp = spacy.load("ru_core_news_sm")
 def segment_text(text):
     """
-    Use a neural network model to segment text into words.
+    Use SpaCy to segment text into words.
     """
-    # Encode the input text for the model as UTF-8 bytes
-    inputs = tokenizer.encode("segment: " + text, return_tensors="pt")
+    # Process the text with SpaCy
+    doc = nlp(text)
 
-    # Generate predictions
-    with torch.no_grad():
-        outputs = model.generate(inputs)
-
-    # Decode the generated tokens back to text
-    segmented_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    # Extract words from the processed document
+    segmented_text = ' '.join([token.text for token in doc])
 
     return segmented_text
+
diff --git a/requirements.txt b/requirements.txt
index 59ed4f7..c4d9f2b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 redis[hiredis]
 aiohttp
 aiofiles
-torch
+spacy
 transformers
 easyocr
 # protobuf