ruffed
This commit is contained in:
@@ -3,17 +3,22 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Load tokenizer and model weights
|
||||
tokenizer = BertTokenizer.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
||||
model = BertForSequenceClassification.from_pretrained('SkolkovoInstitute/russian_toxicity_classifier')
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"SkolkovoInstitute/russian_toxicity_classifier"
|
||||
)
|
||||
model = BertForSequenceClassification.from_pretrained(
|
||||
"SkolkovoInstitute/russian_toxicity_classifier"
|
||||
)
|
||||
|
||||
|
||||
def detector(text):
|
||||
# Prepare the input
|
||||
batch = tokenizer.encode(text, return_tensors='pt')
|
||||
batch = tokenizer.encode(text, return_tensors="pt")
|
||||
|
||||
# Inference
|
||||
with torch.no_grad():
|
||||
result = model(batch)
|
||||
|
||||
|
||||
# Get logits
|
||||
logits = result.logits
|
||||
|
||||
@@ -22,9 +27,11 @@ def detector(text):
|
||||
|
||||
return probabilities[0][1].item()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
p = detector(sys.argv[1])
|
||||
toxicity_percentage = p * 100 # Assuming index 1 is for toxic class
|
||||
print(f"Toxicity Probability: {toxicity_percentage:.2f}%")
|
||||
print(f"Toxicity Probability: {toxicity_percentage:.2f}%")
|
||||
|
Reference in New Issue
Block a user