from datasets import load_dataset
import evaluate
import torch
from torch.optim import AdamW  # transformers' own AdamW is deprecated; use the PyTorch optimizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from accelerate import Accelerator
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    get_scheduler,
)
checkpoint = "bert-base-uncased" raw_datasets = load_dataset("glue", "sst2") tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def tokenize_func(example):
    return tokenizer(example["sentence"], truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_func, batched=True)
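# Optional sanity check (the printed output below is illustrative): the tokenizer
# adds input_ids, token_type_ids and attention_mask columns alongside the originals.
print(tokenized_datasets["train"].column_names)
# ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask']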
# Keep only the columns the model expects and return PyTorch tensors
tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
# DataCollatorWithPadding pads each batch dynamically to its longest sequence
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=8,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=8,
    collate_fn=data_collator,
)
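# Quick check that dynamic padding works: pull one batch and look at the shapes.
# The sequence length (37 below) is illustrative; it changes from batch to batch.
batch = next(iter(train_dataloader))
print({k: v.shape for k, v in batch.items()})
# e.g. {'labels': torch.Size([8]), 'input_ids': torch.Size([8, 37]), ...}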
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_train_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_steps,
)
accelerator = Accelerator()
model, train_dataloader, eval_dataloader, optimizer = accelerator.prepare(
    model, train_dataloader, eval_dataloader, optimizer
)
# accelerator.prepare() already placed the model (and the batches yielded by the
# prepared dataloaders) on the right device, so no manual .to(device) calls are needed.
progress_bar = tqdm(range(num_train_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        output = model(**batch)
        loss = output.loss
        # With Accelerate, backpropagate through accelerator.backward()
        # instead of calling loss.backward() directly.
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
metrics = evaluate.load("glue", "sst2") model.eval() for batch in eval_dataloader: batch = {k :v.to(device) for k,v in batch.items()} with torch.no_grad(): output = model(**batch) logits = output.logits pred = torch.argmax(logits,axis=-1) metrics.add_batch(predictions=accelerator.gather(pred),references=accelerator.gather(batch['labels'])) metrics.compute()