diff --git a/utils/normalize.py b/utils/normalize.py
index bc15dbd..37b2b9d 100644
--- a/utils/normalize.py
+++ b/utils/normalize.py
@@ -1,48 +1,80 @@
+import logging
+import torch
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+# Initialize the ByT5 model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
+model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
+
 def is_russian_wording(text):
     """
     Check if the text contains any Russian characters by checking
     each character against the Unicode range for Cyrillic.
     """
-    # Check if any character in the text is a Cyrillic character
     for char in text:
         if '\u0400' <= char <= '\u04FF':  # Unicode range for Cyrillic characters
             return True
     return False
 
+def segment_text(text):
+    """
+    Use a neural network model to segment text into words.
+    """
+    # Encode the input text for the model
+    inputs = tokenizer.encode("segment: " + text, return_tensors="pt")
+
+    # Generate predictions
+    with torch.no_grad():
+        outputs = model.generate(inputs)
+
+    # Decode the generated tokens back to text
+    segmented_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return segmented_text
+
 def normalize(text):
     """
     Normalize English text to resemble Russian characters.
     """
-    text = text.lower()
-    if is_russian_wording(text):
-        # Normalize the text by replacing characters
-        text = (text
-                .replace(' ', ' ')
-                .replace(' ', ' ')
-                .replace(' ', ' ')
-                .replace('e', 'е')
-                .replace('o', 'о')
-                .replace('x', 'х')
-                .replace('a', 'а')
-                .replace('r', 'г')
-                .replace('m', 'м')
-                .replace('u', 'и')
-                .replace('n', 'п')
-                .replace('p', 'р')
-                .replace('t', 'т')
-                .replace('y', 'у')
-                .replace('h', 'н')
-                .replace('p', 'р')
-                .replace('i', 'й')
-                .replace('c', 'с')
-                .replace('k', 'к')
-                .replace('b', 'в')
-                .replace('3', 'з')
-                .replace('4', 'ч')
-                .replace('0', 'о')
-                .replace('e', 'е')
-                .replace('d', 'д')
-                .replace('z', 'з')
-                )
+    # Segment the text first
+    segmented_text = segment_text(text.replace(' ', ' ').replace(' ', ' ').replace(' ', ' '))
 
-    return text
+    # Normalize after segmentation
+    segmented_text = segmented_text.lower()
+
+    if is_russian_wording(segmented_text):
+        # Normalize the text by replacing characters
+        normalized_text = (segmented_text
+                           .replace('e', 'е')
+                           .replace('o', 'о')
+                           .replace('x', 'х')
+                           .replace('a', 'а')
+                           .replace('r', 'г')
+                           .replace('m', 'м')
+                           .replace('u', 'и')
+                           .replace('n', 'п')
+                           .replace('p', 'р')
+                           .replace('t', 'т')
+                           .replace('y', 'у')
+                           .replace('h', 'н')
+                           .replace('i', 'й')
+                           .replace('c', 'с')
+                           .replace('k', 'к')
+                           .replace('b', 'в')
+                           .replace('3', 'з')
+                           .replace('4', 'ч')
+                           .replace('0', 'о')
+                           .replace('d', 'д')
+                           .replace('z', 'з'))
+
+        return normalized_text
+
+    return segmented_text
+
+# Example usage
+if __name__ == "__main__":
+    input_text = "Hello, this is a test input."
+
+    normalized_output = normalize(input_text)
+    print(normalized_output)
+