Spot instances

What is a spot instance?

Spot instances are Amazon EC2 instances that run on spare AWS capacity at a steep discount. SavviHub supports Amazon EC2 Spot Instances on Amazon Elastic Kubernetes Service (EKS) for running experiments. Spot instances are attractive in terms of price and performance compared to On-Demand instances, especially for stateless and fault-tolerant container runs. According to the AWS documentation, spot instances can save you up to 90% compared to On-Demand instances.

Disclaimer

Since spot instances are subject to interruption, a claimed spot instance is reclaimed with a two-minute warning when AWS needs the capacity elsewhere. Thus, saving and loading the model at every epoch is recommended. Fortunately, most ML toolkits, such as Fairseq and Detectron2, provide model checkpointing to keep the best-performing model. For checkpoint tutorials, see the saving-and-loading-models documents of PyTorch and TensorFlow.
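The resume pattern is framework-agnostic: persist the epoch counter (and any best-metric state) alongside the weights, and restore both on startup. Below is a minimal standard-library sketch of the idea; the `checkpoint.json` file name and the stand-in accuracy formula are illustrative only, not part of the SavviHub API or the framework examples that follow.

```python
import json
import os

CHECKPOINT_FILE = 'checkpoint.json'  # illustrative path

def save_checkpoint(epoch, best_accuracy):
    # Persist everything needed to resume: the epoch counter and the best metric.
    with open(CHECKPOINT_FILE, 'w') as f:
        json.dump({'epoch': epoch, 'best_accuracy': best_accuracy}, f)

def load_checkpoint():
    # Resume from the last saved state, or start from scratch.
    if os.path.isfile(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE) as f:
            state = json.load(f)
        return state['epoch'], state['best_accuracy']
    return 0, 0.0

start_epoch, best_accuracy = load_checkpoint()
for epoch in range(start_epoch, 5):       # 5 stands in for the total epoch count
    accuracy = 0.5 + 0.1 * epoch          # stand-in for a real validation run
    if accuracy > best_accuracy:
        best_accuracy = accuracy
    save_checkpoint(epoch + 1, best_accuracy)
```

If the process is killed mid-training and rerun, `load_checkpoint()` returns the saved epoch, so the loop continues where it left off instead of restarting at zero.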

How to use a spot instance?

Please refer to the example code in the SavviHub examples GitHub repository.

1. Save checkpoints

While training the model, you need to save checkpoints at regular intervals. The following PyTorch and Keras snippets compare validation accuracy at each epoch and save the best model along with the epoch value. Note that the example code saves the epoch value in the checkpoint so that it can later be loaded as the start_epoch value.
PyTorch
Keras
import torch

def save_checkpoint(state, is_best, filename):
    if is_best:
        print("=> Saving a new best")
        torch.save(state, filename)
    else:
        print("=> Validation Accuracy did not improve")


for epoch in range(epochs):
    train(...)
    test_accuracy = test(...)

    test_accuracy = torch.FloatTensor([test_accuracy])
    is_best = bool(test_accuracy.numpy() > best_accuracy.numpy())
    best_accuracy = torch.FloatTensor(
        max(test_accuracy.numpy(), best_accuracy.numpy()))
    save_checkpoint({
        'epoch': start_epoch + epoch + 1,
        'state_dict': model.state_dict(),
        'best_accuracy': best_accuracy,
    }, is_best, checkpoint_file_path)
from savvihub.keras import SavviHubCallback
from keras.callbacks import ModelCheckpoint
import tensorflow as tf
import os

checkpoint_path = os.path.join(args.checkpoint_path, 'checkpoints-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)

checkpoint_callback = ModelCheckpoint(
    checkpoint_path,
    monitor='val_accuracy',
    verbose=1,
    save_weights_only=True,
    mode='max',
    save_freq=args.save_model_freq,
)

# Compile model
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

model.save_weights(checkpoint_path.format(epoch=0))

model.fit(x_train, y_train,
          batch_size=args.batch_size,
          validation_data=(x_val, y_val),
          epochs=args.epochs,
          callbacks=[
              SavviHubCallback(
                  data_type='image',
                  validation_data=(x_val, y_val),
                  num_images=5,
                  start_epoch=start_epoch,
                  save_image=args.save_image,
              ),
              checkpoint_callback,
          ])

2. Load checkpoints

When a spot instance interruption occurs, the code is executed again from the beginning. Therefore, you need to write code that loads the saved checkpoint before training.
PyTorch
Keras
import torch
import os

def load_checkpoint(checkpoint_file_path):
    print(f"=> Loading checkpoint '{checkpoint_file_path}' ...")
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_file_path)
    else:
        checkpoint = torch.load(checkpoint_file_path,
                                map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint.get('state_dict'))
    print(f"=> Loaded checkpoint (trained for {checkpoint.get('epoch')} epochs)")
    return checkpoint.get('epoch'), checkpoint.get('best_accuracy')


if os.path.exists(args.checkpoint_path) and os.path.isfile(checkpoint_file_path):
    start_epoch, best_accuracy = load_checkpoint(checkpoint_file_path)
else:
    print("=> No checkpoint found! Training from scratch")
    start_epoch, best_accuracy = 0, torch.FloatTensor([0])
    if not os.path.exists(args.checkpoint_path):
        print(f" [*] Make directories : {args.checkpoint_path}")
        os.makedirs(args.checkpoint_path)
import os
import tensorflow as tf

def parse_epoch(file_path):
    return int(os.path.splitext(os.path.basename(file_path))[0].split('-')[1])


checkpoint_path = os.path.join(args.checkpoint_path, 'checkpoints-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)
if os.path.exists(checkpoint_dir) and len(os.listdir(checkpoint_dir)) > 0:
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    print(f"=> Loading checkpoint '{latest}' ...")
    model.load_weights(latest)
    start_epoch = parse_epoch(latest)
    print(f'start_epoch:{start_epoch}')
else:
    start_epoch = 0
    if not os.path.exists(args.checkpoint_path):
        print(f" [*] Make directories : {args.checkpoint_path}")
        os.makedirs(args.checkpoint_path)
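For reference, parse_epoch recovers the epoch number from a checkpoint filename produced by the checkpoints-{epoch:04d}.ckpt template. A small standalone check (the sample path is illustrative):

```python
import os

def parse_epoch(file_path):
    # 'checkpoints-0012.ckpt' -> 'checkpoints-0012' -> ['checkpoints', '0012'] -> 12
    return int(os.path.splitext(os.path.basename(file_path))[0].split('-')[1])

print(parse_epoch('/tmp/ckpt/checkpoints-0012.ckpt'))  # -> 12
```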
The start_epoch value is also useful when logging metrics to the SavviHub server. Without the offset, the step values restart from zero after a spot instance interruption, and the metrics graph may break.
PyTorch
Keras
import savvihub

def train(...):
    ...
    savvihub.log(
        step=epoch + start_epoch + 1,
        row={'loss': loss.item()}
    )
from savvihub.keras import SavviHubCallback

model.fit(...,
          callbacks=[SavviHubCallback(
              ...,
              start_epoch=start_epoch,
              ...,
          )]
)
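To see why the offset matters, consider a run interrupted after 5 epochs and then restarted. A plain-Python stand-in (no SavviHub calls; the function name is illustrative) showing the step values each run would report:

```python
def steps_for_run(start_epoch, epochs_this_run):
    # Mirrors step = epoch + start_epoch + 1 from the logging snippets above.
    return [epoch + start_epoch + 1 for epoch in range(epochs_this_run)]

first_run = steps_for_run(start_epoch=0, epochs_this_run=5)   # before interruption
second_run = steps_for_run(start_epoch=5, epochs_this_run=5)  # after resuming

print(first_run)   # [1, 2, 3, 4, 5]
print(second_run)  # [6, 7, 8, 9, 10] -- continues instead of restarting at 1
```

With start_epoch restored from the checkpoint, the resumed run continues the step sequence, so the metrics graph stays monotonic across interruptions.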

3. Use the spot instance option

To use a spot instance on SavviHub, tick the "Use Spot Instance" checkbox and the spot instance types will appear. Every spot instance resource type carries the *.spot postfix. More spot resource types will be added in the future.
Tick the spot instance option checkbox to list spot instance resource types