Efficient code is essential when working with high-performance computing (HPC) resources like Berzelius. Whether you’re training deep learning models, analyzing large datasets, or running scientific simulations, understanding the performance characteristics of your Python code can help you reduce runtime, optimize resource usage, and debug bottlenecks.
Jobs are expected to make proper use of the GPUs on Berzelius, and particularly inefficient jobs will be terminated automatically. Please read the Berzelius GPU Usage Efficiency Policy for more details.
py-spy
py-spy is a sampling profiler for Python that lets you analyze running processes without modifying your code. It works by attaching to a Python process and sampling its call stack at regular intervals to determine which functions consume the most CPU time.
Within a Conda environment:
# Using mamba
mamba install py-spy
# Using pip
pip install py-spy
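Because py-spy attaches to a process by its PID, you can also inspect a job that is already running. For example, to print a one-off snapshot of the current call stack of every thread (12345 below is a placeholder for your own process ID):
# Dump the current call stack of a running Python process
py-spy dump --pid 12345
Attaching to another process requires permission to trace it; this is normally the case for your own processes, but some systems restrict ptrace access.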
A flame graph is a visual representation of profiling data that shows which functions consume the most time in your program and how they relate to each other in the call stack. It is one of the most intuitive ways to analyze performance bottlenecks in Python (and many other languages).
You can generate a flame graph with:
py-spy record -o flame-graph-mnist.svg -- python examples/mnist/main.py --epochs 10
Once generated, open the SVG in a browser to explore it interactively (see the py-spy Flame Graph Example).
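py-spy record can also attach to an already-running process and supports a few options worth knowing; for example (the PID and values below are placeholders):
# Attach to a running process, sample for 60 seconds at 200 samples/second,
# and write a speedscope-format profile instead of an SVG flame graph
py-spy record --pid 12345 --duration 60 --rate 200 --format speedscope -o profile.json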
To display a top-like live view of which functions are consuming the most CPU time in your Python script, use:
py-spy top -- python examples/mnist/main.py --epochs 10
This command will launch your script and simultaneously show a real-time function usage summary, updated every second.
Legend:
%Own: percentage of sampled time currently being spent in the function itself
%Total: percentage of sampled time spent in the function and everything it calls
OwnTime: accumulated time spent in the function itself
TotalTime: accumulated time spent in the function and everything it calls
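Like record, py-spy top can attach to an existing process instead of launching the script itself:
# Show the live view for a job that is already running (placeholder PID)
py-spy top --pid 12345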
line_profiler
line_profiler is a deterministic, line-by-line profiler for Python. Unlike py-spy, which samples running code, line_profiler instruments the functions you select and provides detailed timing information for every single line inside them.
This makes it ideal for targeted performance analysis, for example when you suspect that a particular loop or function is a bottleneck.
Within a Conda environment:
# Using mamba
mamba install line_profiler
# Using pip
pip install line_profiler
Use the @profile decorator on any function you want to analyze:
from line_profiler import profile
import torch.nn.functional as F  # provides nll_loss; imported at the top of the full example script

@profile
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break
Run the script with the LINE_PROFILE environment variable set. The script executes normally and the profiling data is written to a profile_output.lprof file:
LINE_PROFILE=1 python examples/mnist/main.py --epochs 10
To view the detailed line-by-line report, run:
python -m line_profiler -rtmz profile_output.lprof
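Alternatively, the kernprof driver that ships with line_profiler runs the script, saves the .lprof file, and prints the report in a single step:
# -l enables line-by-line profiling, -v prints the report when the run finishes
kernprof -l -v examples/mnist/main.py --epochs 10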
A typical line_profiler output looks like:
Timer unit: 1e-06 s

Total time: 98.2199 s
File: /proj/nsc_testing/xuan/examples/mnist/main.py
Function: train at line 37

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    37                                           @profile
    38                                           def train(args, model, device, train_loader, optimizer, epoch):
    39        10       1136.9    113.7      0.0      model.train()
    40      9390   84474146.5   8996.2     86.0      for batch_idx, (data, target) in enumerate(train_loader):
    41      9380     601675.8     64.1      0.6          data, target = data.to(device), target.to(device)
    42      9380     740105.6     78.9      0.8          optimizer.zero_grad()
    43      9380    3680331.0    392.4      3.7          output = model(data)
    44      9380     257046.2     27.4      0.3          loss = F.nll_loss(output, target)
    45      9380    4744006.5    505.8      4.8          loss.backward()
    46      9380    3406183.2    363.1      3.5          optimizer.step()
    47      9380       7312.5      0.8      0.0          if batch_idx % args.log_interval == 0:
    48      1880      15025.2      8.0      0.0              print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    49       940       6573.5      7.0      0.0                  epoch, batch_idx * len(data), len(train_loader.dataset),
    50       940     285663.0    303.9      0.3                  100. * batch_idx / len(train_loader), loss.item()))
    51       940        647.9      0.7      0.0              if args.dry_run:
    52                                                           break

 98.22 seconds - /proj/nsc_testing/xuan/examples/mnist/main.py:37 - train
Focus on the lines with a high % Time value; these are your primary optimization targets.
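In the report above, 86% of the time is spent on the train_loader line, which means the GPU is mostly waiting for input data. A minimal sketch of one common fix, assuming a torchvision MNIST dataset and illustrative parameter values:
# Hypothetical DataLoader configuration: parallel workers and pinned memory
# reduce the time the training loop spends waiting for batches
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset,
                          batch_size=64,
                          shuffle=True,
                          num_workers=4,    # load batches in parallel worker processes
                          pin_memory=True)  # page-locked memory speeds up .to(device) copies
Re-profiling after a change like this shows whether the % Time distribution has shifted away from data loading.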