spacy-words-separation
This commit is contained in:
parent
d9e9c547ef
commit
56a2632980
|
@ -5,6 +5,7 @@ from state.scan import get_average_pattern
|
||||||
from bot.api import telegram_api, download_file
|
from bot.api import telegram_api, download_file
|
||||||
from bot.config import FEEDBACK_CHAT_ID
|
from bot.config import FEEDBACK_CHAT_ID
|
||||||
from handlers.handle_private import handle_private
|
from handlers.handle_private import handle_private
|
||||||
|
from nlp.segment_text import segment_text
|
||||||
from nlp.toxicity_detector import detector
|
from nlp.toxicity_detector import detector
|
||||||
from nlp.normalize import normalize
|
from nlp.normalize import normalize
|
||||||
from nlp.ocr import ocr_recognize
|
from nlp.ocr import ocr_recognize
|
||||||
|
@ -16,22 +17,9 @@ async def messages_routing(msg, state):
|
||||||
cid = msg["chat"]["id"]
|
cid = msg["chat"]["id"]
|
||||||
uid = msg["from"]["id"]
|
uid = msg["from"]["id"]
|
||||||
text = msg.get("caption", msg.get("text", ""))
|
text = msg.get("caption", msg.get("text", ""))
|
||||||
|
|
||||||
for photo in msg.get("photo", []):
|
|
||||||
file_id = photo.get("file_id")
|
|
||||||
if file_id:
|
|
||||||
async for temp_file_path in download_file(file_id):
|
|
||||||
text += ocr_recognize(temp_file_path)
|
|
||||||
text += '\n'
|
|
||||||
|
|
||||||
reply_msg = msg.get("reply_to_message")
|
reply_msg = msg.get("reply_to_message")
|
||||||
|
|
||||||
if cid == uid:
|
if str(cid) == FEEDBACK_CHAT_ID:
|
||||||
# сообщения в личке с ботом
|
|
||||||
logger.info("private chat message: ", msg)
|
|
||||||
await handle_private(msg, state)
|
|
||||||
|
|
||||||
elif str(cid) == FEEDBACK_CHAT_ID:
|
|
||||||
# сообщения из группы обратной связи
|
# сообщения из группы обратной связи
|
||||||
logger.info("feedback chat message: ", msg)
|
logger.info("feedback chat message: ", msg)
|
||||||
logger.debug(msg)
|
logger.debug(msg)
|
||||||
|
@ -44,8 +32,15 @@ async def messages_routing(msg, state):
|
||||||
text=text,
|
text=text,
|
||||||
reply_to_message_id=reply_msg.get("message_id"),
|
reply_to_message_id=reply_msg.get("message_id"),
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
elif bool(text):
|
elif cid == uid:
|
||||||
|
# сообщения в личке с ботом
|
||||||
|
logger.info("private chat message: ", msg)
|
||||||
|
await handle_private(msg, state)
|
||||||
|
return
|
||||||
|
|
||||||
|
elif bool(text) or msg.get("photo"):
|
||||||
mid = msg.get("message_id")
|
mid = msg.get("message_id")
|
||||||
if text == "/toxic@welcomecenter_bot":
|
if text == "/toxic@welcomecenter_bot":
|
||||||
# latest in chat
|
# latest in chat
|
||||||
|
@ -106,8 +101,17 @@ async def messages_routing(msg, state):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
# on screen recognition
|
||||||
|
for photo in msg.get("photo", []):
|
||||||
|
file_id = photo.get("file_id")
|
||||||
|
if file_id:
|
||||||
|
async for temp_file_path in download_file(file_id):
|
||||||
|
text += ocr_recognize(temp_file_path)
|
||||||
|
text += '\n'
|
||||||
|
|
||||||
normalized_text = normalize(text)
|
normalized_text = normalize(text)
|
||||||
toxic_score = detector(normalized_text)
|
segmented_text = segment_text(normalized_text)
|
||||||
|
toxic_score = detector(segmented_text)
|
||||||
toxic_perc = math.floor(toxic_score * 100)
|
toxic_perc = math.floor(toxic_score * 100)
|
||||||
|
|
||||||
if toxic_perc > 49:
|
if toxic_perc > 49:
|
||||||
|
|
|
@ -1,25 +1,17 @@
|
||||||
import torch
|
import spacy
|
||||||
from transformers import ByT5Tokenizer, T5ForConditionalGeneration
|
|
||||||
|
|
||||||
# Use ByT5 for the ByT5 model
|
|
||||||
tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
|
|
||||||
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Load the Russian language model
|
||||||
|
nlp = spacy.load("ru_core_news_sm")
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
"""
|
"""
|
||||||
Use a neural network model to segment text into words.
|
Use SpaCy to segment text into words.
|
||||||
"""
|
"""
|
||||||
# Encode the input text for the model as UTF-8 bytes
|
# Process the text with SpaCy
|
||||||
inputs = tokenizer.encode("segment: " + text, return_tensors="pt")
|
doc = nlp(text)
|
||||||
|
|
||||||
# Generate predictions
|
# Extract words from the processed document
|
||||||
with torch.no_grad():
|
segmented_text = ' '.join([token.text for token in doc])
|
||||||
outputs = model.generate(inputs)
|
|
||||||
|
|
||||||
# Decode the generated tokens back to text
|
|
||||||
segmented_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
||||||
|
|
||||||
return segmented_text
|
return segmented_text
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
redis[hiredis]
|
redis[hiredis]
|
||||||
aiohttp
|
aiohttp
|
||||||
aiofiles
|
aiofiles
|
||||||
torch
|
spacy
|
||||||
transformers
|
transformers
|
||||||
easyocr
|
easyocr
|
||||||
# protobuf
|
# protobuf
|
||||||
|
|
Loading…
Reference in New Issue
Block a user