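"""Detect and normalize Russian text written with Latin lookalike characters.

A ByT5 seq2seq model re-segments spaced-out input into words; if the result
looks Russian (contains more than one Cyrillic character), Latin homoglyphs
and leet-style digits are mapped to their Cyrillic counterparts.
"""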
import torch
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

# Load the byte-level ByT5 tokenizer and model (small checkpoint).
tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")


def is_russian_wording(text):
    """
    Return True if the text contains more than one Cyrillic character,
    checking each character against the Unicode range for Cyrillic.
    """
    counter = 0
    for char in text:
        if "\u0400" <= char <= "\u04ff":  # Unicode range for Cyrillic characters
            counter += 1
            if counter > 1:
                return True
    return False


def segment_text(text):
    """
    Use the ByT5 model to segment the text into words.
    """
    # Encode the input text for the model
    inputs = tokenizer("segment: " + text, return_tensors="pt")

    # Generate predictions
    with torch.no_grad():
        outputs = model.generate(**inputs)

    # Decode the generated tokens back to text
    segmented_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return segmented_text


def normalize(text):
    """
    Normalize text written with Latin lookalike characters into Cyrillic.
    """
    # Collapse runs of repeated spaces, then segment the text into words.
    t = segment_text(
        text.replace("    ", " ").replace("   ", " ").replace("  ", " ")
    )

    if is_russian_wording(t):
        # Map Latin homoglyphs and leet-style digits to Cyrillic characters.
        normalized_text = (
            t.lower()
            .replace("e", "е")
            .replace("o", "о")
            .replace("x", "х")
            .replace("a", "а")
            .replace("r", "г")
            .replace("m", "м")
            .replace("u", "и")
            .replace("n", "п")
            .replace("p", "р")
            .replace("t", "т")
            .replace("y", "у")
            .replace("h", "н")
            .replace("i", "й")
            .replace("c", "с")
            .replace("k", "к")
            .replace("b", "в")
            .replace("3", "з")
            .replace("4", "ч")
            .replace("0", "о")
            .replace("d", "д")
            .replace("z", "з")
        )

        return normalized_text

    return t


# Example usage
if __name__ == "__main__":
    input_text = "привет шп ана т у п а я"

    normalized_output = normalize(input_text)
    print(normalized_output)