spacy-words-separation
This commit is contained in:
parent
d9e9c547ef
commit
56a2632980
|
@ -5,6 +5,7 @@ from state.scan import get_average_pattern
|
|||
from bot.api import telegram_api, download_file
|
||||
from bot.config import FEEDBACK_CHAT_ID
|
||||
from handlers.handle_private import handle_private
|
||||
from nlp.segment_text import segment_text
|
||||
from nlp.toxicity_detector import detector
|
||||
from nlp.normalize import normalize
|
||||
from nlp.ocr import ocr_recognize
|
||||
|
@ -16,22 +17,9 @@ async def messages_routing(msg, state):
|
|||
cid = msg["chat"]["id"]
|
||||
uid = msg["from"]["id"]
|
||||
text = msg.get("caption", msg.get("text", ""))
|
||||
|
||||
for photo in msg.get("photo", []):
|
||||
file_id = photo.get("file_id")
|
||||
if file_id:
|
||||
async for temp_file_path in download_file(file_id):
|
||||
text += ocr_recognize(temp_file_path)
|
||||
text += '\n'
|
||||
|
||||
reply_msg = msg.get("reply_to_message")
|
||||
|
||||
if cid == uid:
|
||||
# сообщения в личке с ботом
|
||||
logger.info("private chat message: ", msg)
|
||||
await handle_private(msg, state)
|
||||
|
||||
elif str(cid) == FEEDBACK_CHAT_ID:
|
||||
if str(cid) == FEEDBACK_CHAT_ID:
|
||||
# сообщения из группы обратной связи
|
||||
logger.info("feedback chat message: ", msg)
|
||||
logger.debug(msg)
|
||||
|
@ -44,8 +32,15 @@ async def messages_routing(msg, state):
|
|||
text=text,
|
||||
reply_to_message_id=reply_msg.get("message_id"),
|
||||
)
|
||||
return
|
||||
|
||||
elif bool(text):
|
||||
elif cid == uid:
|
||||
# сообщения в личке с ботом
|
||||
logger.info("private chat message: ", msg)
|
||||
await handle_private(msg, state)
|
||||
return
|
||||
|
||||
elif bool(text) or msg.get("photo"):
|
||||
mid = msg.get("message_id")
|
||||
if text == "/toxic@welcomecenter_bot":
|
||||
# latest in chat
|
||||
|
@ -106,8 +101,17 @@ async def messages_routing(msg, state):
|
|||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# on screen recognition
|
||||
for photo in msg.get("photo", []):
|
||||
file_id = photo.get("file_id")
|
||||
if file_id:
|
||||
async for temp_file_path in download_file(file_id):
|
||||
text += ocr_recognize(temp_file_path)
|
||||
text += '\n'
|
||||
|
||||
normalized_text = normalize(text)
|
||||
toxic_score = detector(normalized_text)
|
||||
segmented_text = segment_text(normalized_text)
|
||||
toxic_score = detector(segmented_text)
|
||||
toxic_perc = math.floor(toxic_score * 100)
|
||||
|
||||
if toxic_perc > 49:
|
||||
|
|
|
@ -1,25 +1,17 @@
|
|||
import torch
|
||||
from transformers import ByT5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
# Use ByT5 for the ByT5 model
|
||||
tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
|
||||
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
|
||||
|
||||
|
||||
import spacy
|
||||
|
||||
# Load the Russian language model
|
||||
nlp = spacy.load("ru_core_news_sm")
|
||||
|
||||
def segment_text(text):
|
||||
"""
|
||||
Use a neural network model to segment text into words.
|
||||
Use SpaCy to segment text into words.
|
||||
"""
|
||||
# Encode the input text for the model as UTF-8 bytes
|
||||
inputs = tokenizer.encode("segment: " + text, return_tensors="pt")
|
||||
# Process the text with SpaCy
|
||||
doc = nlp(text)
|
||||
|
||||
# Generate predictions
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(inputs)
|
||||
|
||||
# Decode the generated tokens back to text
|
||||
segmented_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
# Extract words from the processed document
|
||||
segmented_text = ' '.join([token.text for token in doc])
|
||||
|
||||
return segmented_text
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
redis[hiredis]
|
||||
aiohttp
|
||||
aiofiles
|
||||
torch
|
||||
spacy
|
||||
transformers
|
||||
easyocr
|
||||
# protobuf
|
||||
|
|
Loading…
Reference in New Issue
Block a user