From 4a025a5595ce749cf989d262d06dfbe72a0a305a Mon Sep 17 00:00:00 2001
From: Untone
Date: Thu, 26 Sep 2024 13:24:18 +0300
Subject: [PATCH] fix-percentage

---
 nlp/toxycity_detector.py | 27 ++++++++++++++++++++++-----
 requirements.txt         |  2 +-
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/nlp/toxycity_detector.py b/nlp/toxycity_detector.py
index 0cb2fe9..19c32c8 100644
--- a/nlp/toxycity_detector.py
+++ b/nlp/toxycity_detector.py
@@ -1,13 +1,30 @@
 from transformers import BertTokenizer, BertForSequenceClassification
+import torch
+import torch.nn.functional as F
 
-# load tokenizer and model weights
+# Load tokenizer and model weights
 tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
 model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
 
-
 def detector(text):
-    # prepare the input
+    # Prepare the input
    batch = tokenizer.encode(text, return_tensors='pt')
 
-    # inference
-    model(batch)
\ No newline at end of file
+    # Inference
+    with torch.no_grad():
+        result = model(batch)
+
+    # Get logits
+    logits = result.logits
+
+    # Convert logits to probabilities using softmax
+    probabilities = F.softmax(logits, dim=1)
+
+    return probabilities[0][1].item()
+
+if __name__ == "__main__":
+    import sys
+    if len(sys.argv) > 1:
+        p = detector(sys.argv[1])
+        toxicity_percentage = p * 100  # Assuming index 1 is for toxic class
+        print(f"Toxicity Probability: {toxicity_percentage:.2f}%")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c784ec9..a29552b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 aiohttp
 redis[hiredis]
-tensorflow
+torch
 transformers
\ No newline at end of file
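
For reference, a minimal usage sketch of the patched detector when imported from Python rather than run through the new __main__ block. It assumes the repository root is on PYTHONPATH (so nlp.toxycity_detector is importable) and that the SkolkovoInstitute/russian_toxicity_classifier weights can be downloaded from the Hugging Face hub; the sample sentence below is illustrative only and not part of the patch.

# Minimal usage sketch (not part of the patch); assumptions as noted above.
from nlp.toxycity_detector import detector

sample = "Пример нейтрального текста"  # illustrative Russian input, not from the patch
p = detector(sample)                   # probability of the toxic class (index 1)
print(f"Toxicity Probability: {p * 100:.2f}%")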