diff --git a/nlp/toxycity_detector.py b/nlp/toxycity_detector.py
index 0cb2fe9..19c32c8 100644
--- a/nlp/toxycity_detector.py
+++ b/nlp/toxycity_detector.py
@@ -1,13 +1,30 @@
 from transformers import BertTokenizer, BertForSequenceClassification
+import torch
+import torch.nn.functional as F
 
-# load tokenizer and model weights
+# Load tokenizer and model weights
 tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
 model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
-
 
 def detector(text):
-    # prepare the input
+    # Prepare the input
     batch = tokenizer.encode(text, return_tensors='pt')
 
-    # inference
-    model(batch)
\ No newline at end of file
+    # Inference
+    with torch.no_grad():
+        result = model(batch)
+
+    # Get logits
+    logits = result.logits
+
+    # Convert logits to probabilities using softmax
+    probabilities = F.softmax(logits, dim=1)
+
+    return probabilities[0][1].item()
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) > 1:
+        p = detector(sys.argv[1])
+        toxicity_percentage = p * 100  # Assuming index 1 is for toxic class
+        print(f"Toxicity Probability: {toxicity_percentage:.2f}%")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c784ec9..a29552b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 aiohttp
 redis[hiredis]
-tensorflow
+torch
 transformers
\ No newline at end of file
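
Usage note (not part of the diff): a minimal sketch of exercising the updated module, based on the CLI entry point and detector() function added above. The example text and the package-style import are illustrative assumptions; the import form requires nlp/ to be importable as a package from the repository root.

    # Command line: pass the text to classify as the first argument;
    # the script prints the toxic-class probability as a percentage.
    #   python nlp/toxycity_detector.py "Пример текста для проверки"

    # From Python: call detector() directly; it returns a float in [0, 1]
    # corresponding to the toxic class (index 1) of the classifier.
    from nlp.toxycity_detector import detector

    probability = detector("Пример текста для проверки")
    print(f"toxic probability: {probability:.3f}")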