Checkpointing

A complete guide to saving and restoring training progress with jd_update_checkpoint() and jd_get_last_checkpoint().

Overview

The jd-worker library provides two checkpointing functions that your entry script can call to save intermediate training state to the central job server:

| Function | What it does |
| --- | --- |
| jd_update_checkpoint(obj) | Serialises any Python object with pickle and uploads it to the server as a versioned checkpoint file. |
| jd_get_last_checkpoint() | Downloads the most recent checkpoint for the current job and returns it as a Python object. Returns None if no checkpoint exists yet. |

Both functions are zero-configuration inside your script. They automatically read JD_JOB_ID and JD_SERVER from the environment — values that jd_worker_cli injects before calling your script. You never need to pass job IDs or server URLs explicitly.
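
To see what the two functions pick up, a script can inspect the same environment variables itself. The sketch below is illustrative only: read_jd_env is a hypothetical debugging helper, not part of jd-worker, and it assumes nothing beyond the two variable names mentioned above.

```python
import os


def read_jd_env(environ=None):
    """Return the job context that jd_worker_cli is described as injecting.

    Hypothetical helper for debugging; not part of the jd library.
    """
    env = os.environ if environ is None else environ
    # Collect every required variable that is absent or empty
    missing = [name for name in ("JD_JOB_ID", "JD_SERVER") if not env.get(name)]
    if missing:
        raise RuntimeError(
            "Missing " + ", ".join(missing) + "; was the script started by jd_worker_cli?"
        )
    return {"job_id": env["JD_JOB_ID"], "server": env["JD_SERVER"]}
```

Calling read_jd_env() at the top of an entry script gives a clear error when the script is launched by hand rather than through jd_worker_cli.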

Python
from jd import jd_update_checkpoint, jd_get_last_checkpoint

# At the start of training — try to resume
ckpt = jd_get_last_checkpoint()    # None on first run

# During training — save progress
jd_update_checkpoint({"epoch": 5, "model": model.state_dict()})

Why Use Server-Side Checkpoints?

A file saved locally in the directory returned by jd_job_dir() stays on the worker's disk, so it is lost if the node is preempted, rebooted, or deallocated. Server-side checkpoints address this in two ways:

  1. Fault tolerance on preemptible hardware. Slurm jobs, cloud spot instances, and shared HPC nodes can be killed at any time. When a checkpoint is stored on the server, any subsequent restart of the same job — on any machine — can download it and continue from the last saved epoch instead of restarting from scratch.
  2. Centralised access across machines. In a heterogeneous sweep running on 10 different machines, the checkpoint for job #42 is always reachable from the server dashboard regardless of which machine ran it. You can view and download checkpoint files from the dashboard's job result panel.

💡 For jobs that run quickly (under a few minutes), checkpointing is not necessary. Focus on jobs that take tens of minutes or more: those are the ones where an unexpected interruption is costly.

1 API Reference

function jd.jd_update_checkpoint(obj) → dict

Serialises obj using Python's pickle and uploads it to the job server. Each call creates a new versioned file — previous versions are never overwritten, so you always have a history of checkpoints:

path
checkpoint_v0_2026-05-20T14-31-22.pkl   ← first call
checkpoint_v1_2026-05-20T15-10-05.pkl   ← second call
checkpoint_v2_2026-05-20T16-02-44.pkl   ← third call  (this is "latest")
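
As an illustration of what "latest" means in this naming scheme, the sketch below picks the highest-versioned name from a list of filenames. The regular expression is inferred from the three example names above and is an assumption, not a documented format; latest_checkpoint is a hypothetical helper, and the server's own selection logic may differ.

```python
import re

# Pattern inferred from the example filenames above (an assumption)
_CKPT_RE = re.compile(r"^checkpoint_v(\d+)_.+\.pkl$")


def latest_checkpoint(filenames):
    """Return the filename with the highest version number, or None."""
    best_name, best_version = None, -1
    for name in filenames:
        match = _CKPT_RE.match(name)
        if match is None:
            continue  # ignore files that are not checkpoints
        version = int(match.group(1))  # compare numerically, not as text
        if version > best_version:
            best_name, best_version = name, version
    return best_name
```

The numeric comparison matters once a job passes ten checkpoints, because "v10" sorts before "v9" in a plain string sort.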

Parameter:

  • obj — any picklable Python object. Typically a dict containing the model state dict, optimizer state, current epoch, and any other metadata needed to resume. Maximum serialised size: 100 MB.

Returns: {"success": True, "filename": "checkpoint_v0_…", "version": 0, "size_bytes": …}

Python
from jd import jd_update_checkpoint

result = jd_update_checkpoint({
    "epoch":     5,
    "model":     model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "val_acc":   0.913,
})
print(result["filename"])   # checkpoint_v0_2026-05-20T14-31-22.pkl
print(result["version"])    # 0

function jd.jd_get_last_checkpoint() → object | None

Downloads the highest-versioned checkpoint for the current job and deserialises it directly in memory — no temporary file is written to disk.

Returns: the Python object that was originally passed to jd_update_checkpoint(), or None if no checkpoint has been saved yet for this job.

Call this near the start of your script, before the training loop. If a checkpoint is returned, load it into the model and optimizer as soon as they are constructed. If None is returned, initialise from scratch.

Python
from jd import jd_get_last_checkpoint

ckpt = jd_get_last_checkpoint()

if ckpt is not None:
    # Resume from the last checkpoint
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from epoch {ckpt['epoch']}, val_acc={ckpt.get('val_acc')}")
else:
    # First run — start from scratch
    start_epoch = 0

2 Usage Patterns

Pattern 1 — Basic periodic checkpoint

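The simplest approach: save a checkpoint every N epochs. This script contains no resumption logic, so a restarted job still begins at epoch 0; the checkpoint is a safety net that you can download from the dashboard if the job is interrupted. For automatic resumption, use Pattern 2.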

Python
import argparse
import torch
from jd import jd_update_checkpoint

parser = argparse.ArgumentParser()
parser.add_argument("--lr",     type=float, default=0.01)
parser.add_argument("--epochs", type=int,   default=50)
args = parser.parse_args()

model     = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

CHECKPOINT_EVERY = 5   # save every 5 epochs

for epoch in range(args.epochs):
    loss = train_one_epoch(model, optimizer)

    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        jd_update_checkpoint({
            "epoch":     epoch,
            "model":     model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss":      loss,
        })
        print(f"Epoch {epoch}: checkpoint saved.")

Pattern 2 — Full resume on restart

This is the recommended pattern for long-running jobs on preemptible hardware. The script checks for an existing checkpoint at startup and resumes seamlessly. If the node is killed and the job is requeued, it continues from the last saved epoch automatically.

Python
import argparse, json
import torch
from jd import jd_job_dir, jd_upload, jd_update_checkpoint, jd_get_last_checkpoint

parser = argparse.ArgumentParser()
parser.add_argument("--lr",     type=float)
parser.add_argument("--layers", type=int)
parser.add_argument("--epochs", type=int)
args = parser.parse_args()

model     = build_model(layers=args.layers)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# ── Resume from checkpoint if one exists ─────────────────────────────
ckpt        = jd_get_last_checkpoint()
start_epoch = 0
best_acc    = 0.0
if ckpt is not None:
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch = ckpt["epoch"] + 1
    best_acc    = ckpt.get("best_acc", 0.0)   # keep the best score across restarts
    print(f"Resumed from epoch {ckpt['epoch']}")
else:
    print("No checkpoint found; starting from scratch.")

# ── Training loop ─────────────────────────────────────────────────────
for epoch in range(start_epoch, args.epochs):
    loss, val_acc = train_one_epoch(model, optimizer)
    best_acc      = max(best_acc, val_acc)

    # Checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        jd_update_checkpoint({
            "epoch":     epoch,
            "model":     model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "val_acc":   val_acc,
            "best_acc":  best_acc,
        })

# ── Save and upload final result ──────────────────────────────────────
job_dir = jd_job_dir()
job_dir.mkdir(parents=True, exist_ok=True)
result  = {"lr": args.lr, "layers": args.layers, "best_acc": best_acc}
(job_dir / "result.json").write_text(json.dumps(result, indent=2))
jd_upload(job_dir / "result.json")

Pattern 3 — Save only the best model

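Instead of checkpointing at fixed intervals, checkpoint only when validation accuracy improves. This keeps the number of checkpoint versions small and ensures the server always holds the best model seen so far. Note that a restart resumes from the epoch of the best checkpoint, so any later epochs that did not improve accuracy are repeated.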

Python
from jd import jd_update_checkpoint, jd_get_last_checkpoint

ckpt         = jd_get_last_checkpoint()
start_epoch  = 0
best_val_acc = 0.0

if ckpt is not None:
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch  = ckpt["epoch"] + 1
    best_val_acc = ckpt.get("best_val_acc", 0.0)

for epoch in range(start_epoch, args.epochs):
    loss, val_acc = train_one_epoch(model, optimizer)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        jd_update_checkpoint({
            "epoch":        epoch,
            "model":        model.state_dict(),
            "optimizer":    optimizer.state_dict(),
            "best_val_acc": best_val_acc,
        })
        print(f"Epoch {epoch}: new best val_acc={val_acc:.4f} — checkpoint saved.")

Pattern 4 — Non-PyTorch objects

Checkpoints are not limited to PyTorch models. Any picklable Python object works — scikit-learn pipelines, XGBoost models, NumPy arrays, plain dictionaries, or custom classes. The pattern is identical.

Python
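from sklearn.ensemble import GradientBoostingClassifier
from jd import jd_update_checkpoint, jd_get_last_checkpoint

TOTAL_STAGES          = 500
STAGES_PER_CHECKPOINT = 100

# Try to resume a partially fitted GBM. warm_start=True makes fit() add new
# stages to the existing ensemble instead of refitting from scratch.
ckpt   = jd_get_last_checkpoint()
model  = ckpt["model"] if ckpt else GradientBoostingClassifier(warm_start=True)
fitted = ckpt["n_fitted"] if ckpt else 0

# Grow the ensemble in chunks and checkpoint after each one
for n_stages in range(fitted + STAGES_PER_CHECKPOINT, TOTAL_STAGES + 1, STAGES_PER_CHECKPOINT):
    model.set_params(n_estimators=n_stages)
    model.fit(X_train, y_train)
    jd_update_checkpoint({"model": model, "n_fitted": n_stages, "score": model.score(X_val, y_val)})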

ℹ️ Checkpoints use Python's standard pickle format. Any object that can be pickled, including third-party models, custom dataclasses, and nested containers, can be checkpointed. Objects that contain open file handles, database connections, or OS-level resources cannot be pickled and will raise an error.
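
A quick way to find out which entry of a checkpoint dict is the problem is to try pickling each value locally before uploading. The sketch below uses only the standard library; find_unpicklable is a hypothetical helper, not part of jd, and a threading.Lock stands in for any OS-level resource.

```python
import pickle
import threading


def find_unpicklable(state):
    """Return the keys of a checkpoint dict whose values fail to pickle."""
    bad = []
    for key, value in state.items():
        try:
            pickle.dumps(value)
        except Exception:  # pickle may raise TypeError, PicklingError, AttributeError
            bad.append(key)
    return bad


state = {"epoch": 3, "weights": [0.1, 0.2], "lock": threading.Lock()}
print(find_unpicklable(state))  # ['lock']
```

Dropping or rebuilding the offending entries before the call keeps the upload from failing partway through a long training run.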

3 Size Limits & Practical Tips

| Topic | Detail |
| --- | --- |
| Maximum checkpoint size | 100 MB (serialised pickle bytes). Use jd_upload() for larger artefacts. |
| Versioning | Every call to jd_update_checkpoint() creates a new version. Old versions are kept on the server and are visible in the dashboard under the job's result files. |
| What jd_get_last_checkpoint() returns | Always the highest-numbered version (not necessarily the last by time). On first run it returns None. |
| Frequency | Checkpoint every 5–10 epochs for deep-learning jobs. More frequent is safer but slower; for very fast epochs, checkpointing every N minutes by wall-clock time may be preferable. |
| Complementary to jd_upload() | Checkpoints are for intermediate resumption state. Use jd_upload() at the end of training to store final model weights, result files, plots, and any artefact you want to download later. |
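
For the wall-clock approach mentioned under Frequency, one possible shape is a small wrapper that forwards to a save function only when enough time has passed. This is a sketch under stated assumptions: WallClockCheckpointer is a hypothetical class, not part of jd, and the save function is injected so the same code works with jd_update_checkpoint or with a stand-in during testing.

```python
import time


class WallClockCheckpointer:
    """Call save_fn(state) at most once per interval_s seconds."""

    def __init__(self, save_fn, interval_s=600.0, clock=time.monotonic):
        self.save_fn = save_fn
        self.interval_s = interval_s
        self.clock = clock
        self._last_save = clock()  # the first save happens one interval after start

    def maybe_save(self, state):
        """Save if the interval has elapsed; return True when a save happened."""
        now = self.clock()
        if now - self._last_save < self.interval_s:
            return False
        self.save_fn(state)
        self._last_save = now
        return True
```

In a training script this would be created once, for example WallClockCheckpointer(jd_update_checkpoint, interval_s=600), with maybe_save({...}) called at the end of every epoch.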

Checkpoint vs. local file save — when to use which

| | jd_update_checkpoint() | Local file via jd_job_dir() |
| --- | --- | --- |
| Survives node preemption | ✅ Yes, stored on the server | ❌ No, lost if local disk is wiped |
| Accessible from dashboard | ✅ Yes | Only if you also call jd_upload() |
| Max size | 100 MB per checkpoint | Limited by local disk only |
| Speed | Network upload (~seconds) | Local I/O (~milliseconds) |
| Best for | Resumption state (model + optimizer) | Final artefacts, large models, datasets |
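
To decide between the two columns at run time, the pickled size of an object can be measured locally. The sketch below is a rough pre-check, not the library's own validation: fits_checkpoint_limit is a hypothetical helper, the limit is read as 100 decimal megabytes (the more conservative interpretation, since the source does not say which unit is meant), and jd_update_checkpoint() may use a different pickle protocol, so the real size can differ slightly.

```python
import pickle

# Assumption: "100 MB" is treated as 100 * 10^6 bytes
MAX_CHECKPOINT_BYTES = 100 * 1000 * 1000


def checkpoint_size(obj):
    """Return the number of bytes obj occupies when pickled."""
    return len(pickle.dumps(obj))


def fits_checkpoint_limit(obj, limit=MAX_CHECKPOINT_BYTES):
    """Return True if the pickled object is within the size limit."""
    return checkpoint_size(obj) <= limit
```

When the check fails, the table above suggests writing the object under jd_job_dir() and sending it with jd_upload() instead.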