File size: 3,411 Bytes
a15e210 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import json
import torch
import torch.nn as nn
from dataclasses import dataclass
from preprocessing import preprocess_single_string
with open('model_data/vocab_kinopoisk_lstm.json', 'r') as file:
vocab_to_int = json.load(file)
@dataclass
class ConfigRNN:
vocab_size: int
device : str
n_layers : int
embedding_dim : int
hidden_size : int
seq_len : int
bidirectional : bool or int
net_config = ConfigRNN(
vocab_size = len(vocab_to_int)+1,
device='cpu',
n_layers=3,
embedding_dim=64,
hidden_size=64,
seq_len = 100,
bidirectional=False
)
class LSTMClassifier(nn.Module):
def __init__(self, rnn_conf = net_config) -> None:
super().__init__()
self.embedding_dim = rnn_conf.embedding_dim
self.hidden_size = rnn_conf.hidden_size
self.bidirectional = rnn_conf.bidirectional
self.n_layers = rnn_conf.n_layers
self.embedding = nn.Embedding(rnn_conf.vocab_size, self.embedding_dim)
self.lstm = nn.LSTM(
input_size = self.embedding_dim,
hidden_size = self.hidden_size,
bidirectional = self.bidirectional,
batch_first = True,
num_layers = self.n_layers
)
self.bidirect_factor = 2 if self.bidirectional else 1
self.clf = nn.Sequential(
nn.Linear(self.hidden_size * self.bidirect_factor, 32),
nn.Tanh(),
nn.Dropout(),
nn.Linear(32, 3)
)
def model_description(self):
direction = 'bidirect' if self.bidirectional else 'onedirect'
return f'lstm_{direction}_{self.n_layers}'
def forward(self, x: torch.Tensor):
embeddings = self.embedding(x)
out, _ = self.lstm(embeddings)
out = out[:, -1, :] # [все элементы батча, последний h_n, все элементы последнего h_n]
out = self.clf(out.squeeze())
return out
def load_lstm_model():
model = LSTMClassifier()
model.load_state_dict(torch.load('model_data/lstm_model.pth'))
model.eval()
return model
model = load_lstm_model()
def predict_review(review_text, model=model, net_config=net_config, vocab_to_int=vocab_to_int):
sample = preprocess_single_string(review_text, net_config.seq_len, vocab_to_int)
model.eval()
with torch.no_grad():
output = model(sample.unsqueeze(0)).to(net_config.device)
if output.dim() == 1:
output = output.unsqueeze(0) # Adjust if necessary
_, predicted_class = torch.max(output, dim=1)
if predicted_class.item() == 0:
return "Это положительный комментарий! Хорошо, что тебе понравился этот фильм! Можешь перейти в раздел с моделью GPT2 и обсудить с ней фильм!"
elif predicted_class.item() == 1:
return "Скорее всего... это комментарий нейтрального характера.. какой-то ты скучный..."
else:
return "Ты что такой токсик? Будь сдержанее, не понравился фильм - пройди мимо и не порьт авторам настроение, они же старались!" |