Wednesday, August 3, 2022

PyTorch checkpoint state dict

Sadly, RobertaModel.from_pretrained doesn't give you the full picture of how the checkpoint was trained:

>>> from fairseq.models.roberta import RobertaModel
>>> roberta = RobertaModel.from_pretrained('./roberta.base/')
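The hub interface is fine for inference, and you can poke at the underlying nn.Module, but the training setup (optimizer, learning-rate schedule, masking probabilities) isn't surfaced there. A quick sketch of what is visible, continuing the session above and assuming the usual hub interface that wraps the network as .model:

>>> roberta.eval()                 # disable dropout for inspection/inference
>>> type(roberta.model)            # the wrapped nn.Module: architecture and weights
>>> sum(p.numel() for p in roberta.model.parameters())   # parameter count
>>> # nothing here tells you how the model was actually trained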

One has to load the checkpoint directly:

>>> import torch
>>> path = './roberta.base/model.pt'
>>> with open(path, 'rb') as f:
...     state = torch.load(f, map_location=torch.device("cpu"))

>>> state['args']

Namespace(no_progress_bar=False, log_interval=25, log_format='json', tbmf_wrapper=False, seed=1, cpu=False, fp16=True, memory_efficient_fp16=True, fp16_init_scale=4, fp16_scale_window=128, fp16_scale_tolerance=0.0, min_loss_scale=0.0001, threshold_loss_scale=1.0, user_dir=None, criterion='masked_lm', tokenizer=None, bpe=None, optimizer='adam', lr_scheduler='polynomial_decay', task='masked_lm', num_workers=2, skip_invalid_size_inputs_valid_test=True, max_tokens=999999, max_sentences=16, required_batch_size_multiple=1, dataset_impl='mmap', train_subset='train', valid_subset='valid', validate_interval=1, disable_validation=False, only_validate=False, max_sentences_valid=16, curriculum=0, distributed_world_size=512, distributed_rank=0, distributed_backend='nccl', distributed_port=19812, device_id=0, distributed_no_spawn=False, ddp_backend='c10d', bucket_cap_mb=200, fix_batches_to_gpus=False, find_unused_parameters=True, arch='roberta_base', max_epoch=0, max_update=500000, clip_norm=0.0, sentence_avg=False, update_freq=[1], lr=[0.0006], min_lr=-1, use_bmuf=False, global_sync_iter=10, restore_file='checkpoint_last.pt', reset_dataloader=True, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, optimizer_overrides='{}', save_interval=1, save_interval_updates=2000, keep_interval_updates=-1, keep_last_epochs=-1, no_save=False, no_epoch_checkpoints=True, no_last_checkpoints=False, no_save_optimizer_state=False, best_checkpoint_metric='loss', maximize_best_checkpoint_metric=False, adam_betas='(0.9, 0.98)', adam_eps=1e-06, weight_decay=0.01, force_anneal=None, warmup_updates=24000, end_learning_rate=0.0, power=1.0, total_num_update=500000, sample_break_mode='complete', tokens_per_sample=512, mask_prob=0.15, leave_unmasked_prob=0.1, random_token_prob=0.1, activation_fn='gelu', dropout=0.1, attention_dropout=0.1, encoder_embed_dim=768, encoder_layers=12, encoder_attention_heads=12, encoder_ffn_embed_dim=3072, pooler_activation_fn='tanh', max_positions=512, activation_dropout=0.0)
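The args entry is just an argparse Namespace, so individual hyperparameters are easy to pull out, and the checkpoint holds more than the args. A rough sketch, continuing from the session above; the exact top-level keys vary across fairseq versions, but 'args' and 'model' (the weight state dict) are the ones that matter here:

>>> # 'args' is an argparse.Namespace; vars() turns it into a plain dict
>>> hparams = vars(state['args'])
>>> hparams['lr'], hparams['warmup_updates'], hparams['total_num_update']
([0.0006], 24000, 500000)
>>> sorted(state.keys())     # typically 'args', 'model', plus optimizer/extra state
>>> type(state['model'])     # the weights: an ordinary (ordered) state dict of tensors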
