ocr
This commit is contained in:
parent
98b842ef18
commit
32ce2e17c4
23
bot/api.py
23
bot/api.py
|
@ -1,4 +1,6 @@
|
|||
import aiohttp
|
||||
import aiofiles
|
||||
import os
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
from bot.config import BOT_TOKEN
|
||||
|
@ -7,7 +9,6 @@ import logging
|
|||
# Create a logger instance
|
||||
logger = logging.getLogger("bot.api")
|
||||
|
||||
|
||||
# Base URL for Telegram Bot API method calls; embeds the bot token and ends with a trailing slash.
api_base = f"https://api.telegram.org/bot{BOT_TOKEN}/"
|
||||
|
||||
|
||||
|
@ -31,3 +32,23 @@ async def telegram_api(endpoint: str, json_data=None, **kwargs):
|
|||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def download_file(file_id):
    """Asynchronously download a file from Telegram and yield its temporary path.

    This is an async generator: consume it with ``async for``. The temporary
    file is created with ``delete=True``, so it only exists while the
    generator is suspended at the ``yield``; it is removed when the
    surrounding context manager exits.

    Args:
        file_id: Telegram file identifier passed to the ``getFile`` method.

    Yields:
        Path of a temporary file holding the downloaded bytes.

    Raises:
        Exception: If the download request does not return HTTP 200.
    """
    # Resolve the server-side path first: the download URL is built from it.
    file_info = await telegram_api("getFile", file_id=file_id)
    file_path = file_info["result"]["file_path"]

    # File contents are served from the /file/bot<token>/ endpoint, not from
    # the method endpoint that api_base points at.
    download_url = f"https://api.telegram.org/file/bot{BOT_TOKEN}/{file_path}"

    async with aiohttp.ClientSession() as session:
        async with session.get(download_url) as response:
            if response.status != 200:
                raise Exception(f"Failed to download file: {response.status}")

            # Save the downloaded bytes to a temporary location and hand the
            # path to the caller while the file is still open.
            async with aiofiles.tempfile.NamedTemporaryFile(delete=True) as temp_file:
                await temp_file.write(await response.read())
                await temp_file.flush()  # make the bytes visible to readers of the path
                yield temp_file.name
|
|
@ -7,6 +7,7 @@ from bot.config import FEEDBACK_CHAT_ID
|
|||
from handlers.handle_private import handle_private
|
||||
from nlp.toxicity_detector import detector
|
||||
from nlp.normalize import normalize
|
||||
from nlp.ocr import ocr_recognize
|
||||
|
||||
logger = logging.getLogger("handlers.messages_routing")
|
||||
|
||||
|
@ -14,7 +15,15 @@ logger = logging.getLogger("handlers.messages_routing")
|
|||
async def messages_routing(msg, state):
|
||||
cid = msg["chat"]["id"]
|
||||
uid = msg["from"]["id"]
|
||||
text = msg.get("text", msg.get("caption"))
|
||||
text = msg.get("caption", msg.get("text", ""))
|
||||
|
||||
for photo in msg.get("photo", [])
|
||||
file_id = photo.get("file_id")
|
||||
if file_id:
|
||||
async for temp_file_path in download_file(file_id):
|
||||
text += ocr_recognize(temp_file_path)
|
||||
text += '\n'
|
||||
|
||||
reply_msg = msg.get("reply_to_message")
|
||||
|
||||
if cid == uid:
|
||||
|
|
16
nlp/ocr.py
Normal file
16
nlp/ocr.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import easyocr
|
||||
import logging
|
||||
|
||||
# Named after the module path, like the other loggers in this project.
logger = logging.getLogger("nlp.ocr")
|
||||
|
||||
# Initialize the EasyOCR reader
|
||||
reader = easyocr.Reader(['ru'])  # Russian only; built once at import time and reused by ocr_recognize
|
||||
|
||||
def ocr_recognize(file_path):
    """Recognize text in an image file and return it as a single string.

    Args:
        file_path: Path of the image handed to the module-level EasyOCR reader.

    Returns:
        The recognized text fragments joined with single spaces; an empty
        string when the reader returns no detections.
    """
    # Use EasyOCR to detect text in the photo.
    result = reader.readtext(file_path)

    # Each detection is (bounding_box, text, confidence); keep only the text.
    recognized_text = ' '.join(text for _, text, _ in result)
    logger.debug(f'recognized_text: {recognized_text}')
    return recognized_text
|
|
@ -2,5 +2,6 @@ redis[hiredis]
|
|||
aiohttp
|
||||
torch
|
||||
transformers
|
||||
easyocr
|
||||
# protobuf
|
||||
# sentencepiece
|
Loading…
Reference in New Issue
Block a user