Checkpointing
A complete guide to saving and restoring training progress with
jd_update_checkpoint() and jd_get_last_checkpoint().
Overview
The jd-worker library provides two checkpointing functions
that your entry script can call to save intermediate training state to
the central job server:
| Function | What it does |
|---|---|
| jd_update_checkpoint(obj) | Serialises any Python object with pickle and uploads it to the server as a versioned checkpoint file. |
| jd_get_last_checkpoint() | Downloads the most recent checkpoint for the current job and returns it as a Python object. Returns None if no checkpoint exists yet. |
Both functions are zero-configuration inside your script.
They automatically read JD_JOB_ID and JD_SERVER
from the environment — values that jd_worker_cli injects
before calling your script. You never need to pass job IDs or server
URLs explicitly.
from jd import jd_update_checkpoint, jd_get_last_checkpoint
# At the start of training — try to resume
ckpt = jd_get_last_checkpoint() # None on first run
# During training — save progress
jd_update_checkpoint({"epoch": 5, "model": model.state_dict()})
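Because both functions rely on JD_JOB_ID and JD_SERVER being present, a script that is sometimes run by hand, outside jd_worker_cli, may want to detect that case before calling them. The following is a minimal sketch using only the standard library. The helper name running_under_jd is hypothetical and not part of jd-worker, and how the library itself behaves when the variables are missing is not stated here.

```python
import os

# Hypothetical helper, not part of jd-worker: reports whether the two
# environment variables named above are present and non-empty.
REQUIRED_VARS = ("JD_JOB_ID", "JD_SERVER")


def running_under_jd(environ=None):
    """Return True if both JD_JOB_ID and JD_SERVER are set and non-empty."""
    env = os.environ if environ is None else environ
    return all(env.get(name) for name in REQUIRED_VARS)


if __name__ == "__main__":
    if running_under_jd():
        print("Launched by jd_worker_cli: server-side checkpoints available.")
    else:
        print("JD_JOB_ID / JD_SERVER not set: skipping server-side checkpoints.")
```

A script could use such a check to skip checkpoint calls during local debugging runs.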
Why Use Server-Side Checkpoints?
A file saved locally in the directory returned by jd_job_dir() stays
on the worker's disk, so it is lost if the node is preempted, rebooted,
or deallocated. Server-side checkpoints address this in two ways:
- Fault tolerance on preemptible hardware. Slurm jobs, cloud spot instances, and shared HPC nodes can be killed at any time. When a checkpoint is stored on the server, any subsequent restart of the same job — on any machine — can download it and continue from the last saved epoch instead of restarting from scratch.
- Centralised access across machines. In a heterogeneous sweep running on 10 different machines, the checkpoint for job #42 is always reachable from the server dashboard regardless of which machine ran it. You can view and download checkpoint files from the dashboard's job result panel.
1 API Reference
jd_update_checkpoint(obj)
Serialises obj using Python's pickle and
uploads it to the job server. Each call creates a new versioned
file. Previous versions are never overwritten, so you always have
a history of checkpoints:
checkpoint_v0_2026-05-20T14-31-22.pkl ← first call
checkpoint_v1_2026-05-20T15-10-05.pkl ← second call
checkpoint_v2_2026-05-20T16-02-44.pkl ← third call (this is "latest")
Parameter:
- obj: any picklable Python object. Typically a dict containing the model state dict, optimizer state, current epoch, and any other metadata needed to resume. Maximum serialised size: 100 MB.
Returns:
{"success": True, "filename": "checkpoint_v0_…", "version": 0, "size_bytes": …}
from jd import jd_update_checkpoint
result = jd_update_checkpoint({
    "epoch": 5,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "val_acc": 0.913,
})
print(result["filename"])  # checkpoint_v0_2026-05-20T14-31-22.pkl
print(result["version"])   # 0
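Given the 100 MB limit on the serialised object, it can help to measure the pickled size before uploading. The sketch below uses only the standard library. The helper names are hypothetical, the limit is interpreted conservatively as 100 × 10^6 bytes (an assumption, since the exact byte count is not stated), and the result is an estimate because the pickle protocol that jd_update_checkpoint() uses internally is not documented here.

```python
import pickle

# Assumption: "100 MB" is read as 100 * 10**6 bytes, the more conservative
# of the two common interpretations.
MAX_CHECKPOINT_BYTES = 100 * 1000 * 1000


def pickled_size(obj):
    """Return the number of bytes pickle produces for obj (default protocol)."""
    return len(pickle.dumps(obj))


def fits_in_checkpoint(obj, limit=MAX_CHECKPOINT_BYTES):
    """Return True if the pickled form of obj is within the size limit."""
    return pickled_size(obj) <= limit


if __name__ == "__main__":
    state = {"epoch": 5, "val_acc": 0.913, "weights": [0.0] * 1000}
    print(f"{pickled_size(state)} bytes, fits: {fits_in_checkpoint(state)}")
```

If the check fails, the state is a candidate for jd_upload() rather than a checkpoint.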
jd_get_last_checkpoint()
Downloads the highest-versioned checkpoint for the current job and deserialises it directly in memory; no temporary file is written to disk.
Returns: the Python object that was originally
passed to jd_update_checkpoint(), or None
if no checkpoint has been saved yet for this job.
Call this at the start of your script, before the training loop.
If a checkpoint is returned, load it into the model and optimizer
as soon as they have been constructed. If None, initialise from scratch.
from jd import jd_get_last_checkpoint
ckpt = jd_get_last_checkpoint()
if ckpt is not None:
    # Resume from the last checkpoint
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from epoch {ckpt['epoch']}, val_acc={ckpt.get('val_acc')}")
else:
    # First run: start from scratch
    start_epoch = 0
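The "highest-versioned" rule can be illustrated on the filename scheme shown for jd_update_checkpoint(). The sketch below assumes filenames of the form checkpoint_v<N>_<timestamp>.pkl and picks the one with the largest N. The function names are hypothetical, and this is not a description of how the server selects the file. It may be useful for sorting checkpoint files downloaded from the dashboard.

```python
import re

# Assumed filename scheme: checkpoint_v<N>_<timestamp>.pkl
_CKPT_RE = re.compile(r"^checkpoint_v(\d+)_.+\.pkl$")


def checkpoint_version(filename):
    """Return the integer version embedded in filename, or None if absent."""
    match = _CKPT_RE.match(filename)
    return int(match.group(1)) if match else None


def latest_checkpoint(filenames):
    """Return the filename with the highest version number, or None."""
    versioned = [(checkpoint_version(name), name) for name in filenames]
    versioned = [(v, name) for v, name in versioned if v is not None]
    if not versioned:
        return None
    # Compare numerically so that v10 ranks above v9.
    return max(versioned)[1]


if __name__ == "__main__":
    names = [
        "checkpoint_v0_2026-05-20T14-31-22.pkl",
        "checkpoint_v1_2026-05-20T15-10-05.pkl",
        "checkpoint_v2_2026-05-20T16-02-44.pkl",
    ]
    print(latest_checkpoint(names))
```

Comparing the parsed integers rather than the raw strings matters once a job passes ten versions, because a plain string sort would place v10 before v9.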
2 Usage Patterns
Pattern 1 — Basic periodic checkpoint
The simplest approach: save a checkpoint every N epochs. This pattern contains no resumption logic, so a restarted job begins again from epoch 0. The checkpoints act only as a safety net that you can download from the server if the job is interrupted.
import argparse
import torch
from jd import jd_update_checkpoint
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--epochs", type=int, default=50)
args = parser.parse_args()
# MyModel and train_one_epoch stand in for your own model class and training step
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
CHECKPOINT_EVERY = 5  # save every 5 epochs
for epoch in range(args.epochs):
    loss = train_one_epoch(model, optimizer)
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        jd_update_checkpoint({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss": loss,
        })
        print(f"Epoch {epoch}: checkpoint saved.")
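When a checkpoint is only a safety net, a failed upload arguably should not abort training. How jd_update_checkpoint() reports failures is not described in this guide, so the sketch below rests on two assumptions: that errors may surface as exceptions, and that the "success" field of the returned dict may be False. The wrapper try_checkpoint is hypothetical and takes the save function as a parameter, so it can be exercised without a server.

```python
def try_checkpoint(save_fn, state):
    """Call save_fn(state); return True on confirmed success, never raise.

    save_fn is expected to behave like jd_update_checkpoint: it takes the
    object to save and returns a dict containing a "success" key.
    """
    try:
        result = save_fn(state)
    except Exception as exc:  # assumption: upload errors raise exceptions
        print(f"Checkpoint failed, training continues: {exc!r}")
        return False
    # Assumption: a failed upload may also be reported as {"success": False}.
    ok = isinstance(result, dict) and bool(result.get("success"))
    if not ok:
        print(f"Checkpoint not confirmed by server: {result!r}")
    return ok


if __name__ == "__main__":
    def fake_save(state):
        # Stand-in for jd_update_checkpoint, used for demonstration only.
        return {"success": True, "version": 0}

    print(try_checkpoint(fake_save, {"epoch": 4}))
```

Inside the training loop, the call would then read try_checkpoint(jd_update_checkpoint, {...}) in place of the direct call.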
Pattern 2 — Full resume on restart
This is the recommended pattern for long-running jobs on preemptible hardware. The script checks for an existing checkpoint at startup and resumes seamlessly. If the node is killed and the job is requeued, it continues from the last saved epoch automatically.
import argparse
import json
import torch
from jd import jd_job_dir, jd_upload, jd_update_checkpoint, jd_get_last_checkpoint
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float)
parser.add_argument("--layers", type=int)
parser.add_argument("--epochs", type=int)
args = parser.parse_args()
model = build_model(layers=args.layers)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# ── Resume from checkpoint if one exists ─────────────────────────────
ckpt = jd_get_last_checkpoint()
start_epoch = 0
best_acc = 0.0
if ckpt is not None:
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch = ckpt["epoch"] + 1
    # Restore the running best so the final result also covers earlier runs
    best_acc = ckpt.get("best_acc", 0.0)
    print(f"Resumed from epoch {ckpt['epoch']}")
else:
    print("No checkpoint found; starting from scratch.")
# ── Training loop ─────────────────────────────────────────────────────
for epoch in range(start_epoch, args.epochs):
    loss, val_acc = train_one_epoch(model, optimizer)
    best_acc = max(best_acc, val_acc)
    # Checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        jd_update_checkpoint({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "val_acc": val_acc,
            "best_acc": best_acc,
        })
# ── Save and upload final result ──────────────────────────────────────
job_dir = jd_job_dir()
job_dir.mkdir(parents=True, exist_ok=True)
result = {"lr": args.lr, "layers": args.layers, "best_acc": best_acc}
(job_dir / "result.json").write_text(json.dumps(result, indent=2))
jd_upload(job_dir / "result.json")
Pattern 3 — Save only the best model
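Instead of checkpointing at fixed intervals, checkpoint only when validation accuracy improves. This keeps the number of checkpoint versions small and ensures the latest checkpoint on the server is always the best model seen so far. The trade-off is that a restarted job resumes from the epoch of that best checkpoint, so any epochs trained after it are repeated.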
from jd import jd_update_checkpoint, jd_get_last_checkpoint
ckpt = jd_get_last_checkpoint()
start_epoch = 0
best_val_acc = 0.0
if ckpt is not None:
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch = ckpt["epoch"] + 1
    best_val_acc = ckpt.get("best_val_acc", 0.0)
for epoch in range(start_epoch, args.epochs):
    loss, val_acc = train_one_epoch(model, optimizer)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        jd_update_checkpoint({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "best_val_acc": best_val_acc,
        })
        print(f"Epoch {epoch}: new best val_acc={val_acc:.4f}, checkpoint saved.")
Pattern 4 — Non-PyTorch objects
Checkpoints are not limited to PyTorch models. Any picklable Python object works — scikit-learn pipelines, XGBoost models, NumPy arrays, plain dictionaries, or custom classes. The pattern is identical.
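from sklearn.ensemble import GradientBoostingClassifier
from jd import jd_update_checkpoint, jd_get_last_checkpoint
# Try to resume a partially fitted GBM
ckpt = jd_get_last_checkpoint()
if ckpt is not None:
    model, n_fitted = ckpt["model"], ckpt["n_fitted"]
else:
    # warm_start=True makes later fit() calls add trees instead of refitting from scratch
    model, n_fitted = GradientBoostingClassifier(n_estimators=100, warm_start=True), 0
# Grow the ensemble 100 trees at a time up to 500, checkpointing after each stage
for n in range(n_fitted + 100, 501, 100):
    model.set_params(n_estimators=n)
    model.fit(X_train, y_train)
    jd_update_checkpoint({"model": model, "n_fitted": n, "score": model.score(X_val, y_val)})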
Checkpoints use Python's pickle format. Any
object that can be pickled, including third-party models, custom
dataclasses, and nested containers, can be checkpointed. Objects
that contain open file handles, database connections, or other OS-level
resources cannot be pickled and will raise an error.
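To see in advance whether a state object will hit this limitation, you can try pickling it locally before the first upload. The sketch below uses only the standard library. The helper pickling_error is hypothetical, and it exercises pickle directly rather than jd_update_checkpoint().

```python
import os
import pickle


def pickling_error(obj):
    """Return None if obj pickles cleanly, otherwise the exception raised."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, TypeError, AttributeError) as exc:
        return exc
    return None


if __name__ == "__main__":
    # Plain containers of numbers pickle without problems.
    good_state = {"epoch": 3, "history": [0.71, 0.78, 0.82]}
    print(pickling_error(good_state))

    # An open file handle inside the state makes the whole object unpicklable.
    with open(os.devnull, "w") as handle:
        bad_state = {"epoch": 3, "log_file": handle}
        print(repr(pickling_error(bad_state)))
```

A common remedy is to store only plain data in the checkpoint (paths, counters, state dicts) and to reopen files or connections after resuming.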
3 Size Limits & Practical Tips
| Topic | Detail |
|---|---|
| Maximum checkpoint size | 100 MB (serialised pickle bytes). Use jd_upload() for larger artefacts. |
| Versioning | Every call to jd_update_checkpoint() creates a new version. Old versions are kept on the server and are visible in the dashboard under the job's result files. |
| What jd_get_last_checkpoint() returns | Always the highest-numbered version (not necessarily the last by time). On first run it returns None. |
| Frequency | Checkpoint every 5–10 epochs for deep-learning jobs. More frequent is safer but slower; for very fast epochs, checkpointing every N minutes by wall-clock time may be preferable. |
| Complementary to jd_upload() | Checkpoints are for intermediate resumption state. Use jd_upload() at the end of training to store final model weights, result files, plots, and any artefact you want to download later. |
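For the wall-clock option mentioned under Frequency, one possible approach is a small timer object that reports when the chosen interval has elapsed. The sketch below is an illustration under that assumption. IntervalTrigger is a hypothetical helper, not part of jd-worker, and the clock is injectable so the logic can be tested without waiting.

```python
import time


class IntervalTrigger:
    """Reports True at most once per interval_seconds of elapsed time."""

    def __init__(self, interval_seconds, clock=time.monotonic):
        self.interval = interval_seconds
        self.clock = clock
        self.last_fired = clock()  # the first interval starts at construction

    def due(self):
        """Return True and restart the timer if the interval has elapsed."""
        now = self.clock()
        if now - self.last_fired >= self.interval:
            self.last_fired = now
            return True
        return False


# Usage inside a training loop (jd_update_checkpoint comes from the jd package):
#
# trigger = IntervalTrigger(10 * 60)  # every 10 minutes
# for epoch in range(start_epoch, args.epochs):
#     train_one_epoch(model, optimizer)
#     if trigger.due():
#         jd_update_checkpoint({"epoch": epoch, "model": model.state_dict()})
```

time.monotonic is used instead of time.time so that system clock adjustments do not trigger or suppress a checkpoint.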
Checkpoint vs. local file save — when to use which
| | jd_update_checkpoint() | Local file via jd_job_dir() |
|---|---|---|
| Survives node preemption | ✅ Yes — stored on the server | ❌ No — lost if local disk is wiped |
| Accessible from dashboard | ✅ Yes | Only if you also call jd_upload() |
| Max size | 100 MB per checkpoint | Limited by local disk only |
| Speed | Network upload (~seconds) | Local I/O (~milliseconds) |
| Best for | Resumption state (model + optimizer) | Final artefacts, large models, datasets |