add-byt5
This commit is contained in:
parent
b9ac3ee3c6
commit
905b9b177c
|
@ -1,48 +1,80 @@
|
|||
import logging
|
||||
import torch
|
||||
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
# Initialize the ByT5 tokenizer and model once, at import time.
# NOTE(review): from_pretrained downloads weights on first use, so importing
# this module is slow and network-dependent — confirm this is intended rather
# than lazy-loading inside segment_text.
tokenizer = T5Tokenizer.from_pretrained("google/byt5-small")
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
||||
def is_russian_wording(text):
    """Return True if *text* contains at least one Cyrillic character.

    Every character is tested against the Unicode Cyrillic block
    (U+0400 through U+04FF).
    """
    return any("\u0400" <= ch <= "\u04FF" for ch in text)
|
||||
|
||||
def segment_text(text):
    """Segment *text* into words using the ByT5 model.

    The input is prefixed with the "segment: " task marker, run through a
    single generation pass (no gradients), and the prediction is decoded
    back into a plain string.
    """
    task_input = "segment: " + text
    token_ids = tokenizer.encode(task_input, return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(token_ids)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
||||
|
||||
def normalize(text):
    """Lower-case *text* and map Latin/digit homoglyphs to Cyrillic letters.

    If the text contains at least one Cyrillic character it is assumed to be
    Russian written partly with look-alike Latin letters and digits, and each
    look-alike is replaced by its Cyrillic counterpart in a single pass.
    Text with no Cyrillic characters is returned lower-cased, unchanged.

    Fixes over the previous revision:
      * the unconditional ``return text`` made everything after it dead code —
        the dead branch is removed;
      * the ``segment_text(...)`` result computed just before that return was
        discarded, wasting a full model inference — the unused call is removed;
      * the duplicated ``.replace('p', 'р')`` / ``.replace('e', 'е')`` calls
        and the 20+ chained ``.replace`` calls are collapsed into one
        ``str.translate`` pass.
    """
    text = text.lower()

    # Treat the text as Russian only if some character falls in the Cyrillic
    # Unicode block (U+0400-U+04FF); same check as is_russian_wording.
    if not any("\u0400" <= ch <= "\u04FF" for ch in text):
        return text

    # One-pass homoglyph table: Latin letter / digit -> Cyrillic letter.
    homoglyphs = str.maketrans({
        "e": "е", "o": "о", "x": "х", "a": "а", "r": "г", "m": "м",
        "u": "и", "n": "п", "p": "р", "t": "т", "y": "у", "h": "н",
        "i": "й", "c": "с", "k": "к", "b": "в", "d": "д", "z": "з",
        "3": "з", "4": "ч", "0": "о",
    })
    # NOTE(review): the original also chained several ``.replace(' ', ' ')``
    # no-ops — presumably whitespace collapsing garbled in transcription;
    # they are dropped here since as written they had no effect.
    return text.translate(homoglyphs)
|
||||
|
||||
# Example usage: normalize a sample sentence and show the result.
if __name__ == "__main__":
    sample = "Hello, this is a test input."
    print(normalize(sample))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user