diff --git a/bot/api.py b/bot/api.py
index 166c800..436ae37 100644
--- a/bot/api.py
+++ b/bot/api.py
@@ -1,4 +1,6 @@
 import aiohttp
+import aiofiles
+import os
 import json
 from urllib.parse import urlencode
 from bot.config import BOT_TOKEN
@@ -7,7 +9,6 @@ import logging
 # Create a logger instance
 logger = logging.getLogger("bot.api")
 
-
 api_base = f"https://api.telegram.org/bot{BOT_TOKEN}/"
 
 
@@ -31,3 +32,38 @@ async def telegram_api(endpoint: str, json_data=None, **kwargs):
         import traceback
 
         traceback.print_exc()
+
+
+async def download_file(file_id):
+    """Download a Telegram file and yield the path of a temporary copy.
+
+    This is an async generator that yields exactly once. The temporary file
+    is deleted when the generator is resumed or closed, so the path is only
+    valid inside the consumer's ``async for`` body.
+
+    Raises:
+        RuntimeError: if ``getFile`` does not return a file path.
+        Exception: if the download responds with a non-200 status.
+    """
+    # The download URL depends on file_path, so getFile has to run first.
+    file = await telegram_api("getFile", file_id=file_id)
+    try:
+        file_path = file["result"]["file_path"]
+    except (KeyError, TypeError) as err:
+        raise RuntimeError(f"getFile failed for file_id {file_id!r}") from err
+
+    # Files live under /file/bot<token>/, not under the method endpoint.
+    download_url = f"https://api.telegram.org/file/bot{BOT_TOKEN}/{file_path}"
+
+    async with aiohttp.ClientSession() as session:
+        async with session.get(download_url) as response:
+            if response.status != 200:
+                raise Exception(f"Failed to download file: {response.status}")
+            payload = await response.read()
+
+    # The HTTP session is closed before yielding so it is not held open
+    # while the consumer processes the file.
+    async with aiofiles.tempfile.NamedTemporaryFile(delete=True) as temp_file:
+        await temp_file.write(payload)
+        await temp_file.flush()
+        yield temp_file.name
diff --git a/handlers/messages_routing.py b/handlers/messages_routing.py
index 3aa7b28..3cae9af 100644
--- a/handlers/messages_routing.py
+++ b/handlers/messages_routing.py
@@ -7,6 +7,8 @@ from bot.config import FEEDBACK_CHAT_ID
 from handlers.handle_private import handle_private
 from nlp.toxicity_detector import detector
 from nlp.normalize import normalize
+from nlp.ocr import ocr_recognize
+from bot.api import download_file
 
 logger = logging.getLogger("handlers.messages_routing")
 
@@ -14,7 +16,28 @@ logger = logging.getLogger("handlers.messages_routing")
 async def messages_routing(msg, state):
     cid = msg["chat"]["id"]
     uid = msg["from"]["id"]
-    text = msg.get("text", msg.get("caption"))
+    text = msg.get("caption") or msg.get("text") or ""
+
+    # A photo arrives as several sizes of the same image; OCR only the
+    # largest one so its text is not appended once per size.
+    photo_sizes = msg.get("photo") or []
+    if photo_sizes:
+        largest = max(
+            photo_sizes,
+            key=lambda size: size.get("width", 0) * size.get("height", 0),
+        )
+        file_id = largest.get("file_id")
+        try:
+            if file_id:
+                async for temp_file_path in download_file(file_id):
+                    recognized = ocr_recognize(temp_file_path)
+                    if recognized:
+                        # Keep caption and OCR text on separate lines.
+                        text = f"{text}\n{recognized}" if text else recognized
+        except Exception:
+            # OCR is best-effort: still route the message on failure.
+            logger.exception("Failed to OCR photo %s", file_id)
+
     reply_msg = msg.get("reply_to_message")
 
     if cid == uid:
diff --git a/nlp/ocr.py b/nlp/ocr.py
new file mode 100644
index 0000000..924ec24
--- /dev/null
+++ b/nlp/ocr.py
@@ -0,0 +1,22 @@
+import logging
+
+import easyocr
+
+logger = logging.getLogger("nlp.ocr")
+
+# Build the reader once at import time so every call shares it.
+reader = easyocr.Reader(['ru'])  # Specify the languages you want to support
+
+
+def ocr_recognize(file_path):
+    """Return the text EasyOCR recognizes in the image at ``file_path``.
+
+    The recognized fragments are joined with single spaces; an image with
+    no detected text gives an empty string. The call is synchronous.
+    """
+    # readtext() returns (bounding_box, text, confidence) tuples.
+    result = reader.readtext(file_path)
+
+    recognized_text = ' '.join(text for _, text, _ in result)
+    logger.debug('recognized_text: %s', recognized_text)
+    return recognized_text
diff --git a/requirements.txt b/requirements.txt
index a0269bf..015ce7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,7 @@ redis[hiredis]
 aiohttp
+aiofiles
 torch
 transformers
+easyocr
 # protobuf
 # sentencepiece
\ No newline at end of file