Using Python Profiling Tools on Berzelius

1. Introduction

Efficient code is essential when working with high-performance computing (HPC) resources like Berzelius. Whether you’re training deep learning models, analyzing large datasets, or running scientific simulations, understanding the performance characteristics of your Python code can help you reduce runtime, optimize resource usage, and debug bottlenecks.

We expect jobs to make efficient use of the GPUs on Berzelius, and particularly inefficient jobs will be terminated automatically. Please read the Berzelius GPU Usage Efficiency Policy for more details.

2. py-spy

py-spy is a sampling profiler for Python that allows you to analyze running processes without modifying your code. It works by attaching to a Python process and sampling its call stack at regular intervals to determine which functions consume the most CPU time.
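To make the sampling idea concrete, here is a toy, pure-Python sketch of a sampling profiler. It is not how py-spy works internally: py-spy runs as a separate process and reads the target's memory, so the profiled program needs no changes. This sketch only illustrates the principle of periodically snapshotting the call stack and counting what it sees. It relies on CPython's private, implementation-specific sys._current_frames(), and the names sample_stacks, slow_part, fast_part, and workload are hypothetical.

```python
import collections
import sys
import threading
import time


def sample_stacks(target, interval=0.001):
    """Run target() in the current thread while a helper thread periodically
    records the current thread's call stack (a toy sampling profiler).

    Returns a Counter mapping each observed stack, as a tuple of function
    names from outermost to innermost, to the number of samples.
    """
    target_id = threading.get_ident()
    samples = collections.Counter()
    done = threading.Event()

    def sampler():
        while not done.is_set():
            # CPython-specific: snapshot of the frame each thread is running.
            frame = sys._current_frames().get(target_id)
            stack = []
            while frame is not None:
                stack.append(frame.f_code.co_name)
                frame = frame.f_back
            if stack:
                samples[tuple(reversed(stack))] += 1
            time.sleep(interval)  # releases the GIL so target() keeps running

    thread = threading.Thread(target=sampler, daemon=True)
    thread.start()
    try:
        target()
    finally:
        done.set()
        thread.join()
    return samples


def slow_part(seconds=0.2):
    # Busy-wait so this function stays on the CPU for `seconds`.
    end = time.perf_counter() + seconds
    while time.perf_counter() < end:
        pass


def fast_part():
    return sum(range(1000))


def workload():
    slow_part()
    fast_part()


if __name__ == "__main__":
    counts = sample_stacks(workload)
    total = sum(counts.values())
    for stack, n in counts.most_common(3):
        print(f"{100 * n / total:5.1f}%  {' -> '.join(stack)}")
```

Because samples are taken at intervals, a function's share of samples only estimates its share of runtime; short-lived functions such as fast_part may not be sampled at all.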

2.1 Installing py-spy

Within a Conda environment:

# Using mamba
mamba install py-spy
# Using pip
pip install py-spy

2.2 Usage

2.2.1 Record a Flame Graph

A flame graph is a visual representation of profiling data that shows which functions consume the most time in your program and how they relate to each other in the call stack. The width of each box is proportional to the share of samples in which that function was on the stack. It is one of the most intuitive ways to analyze performance bottlenecks in Python (and many other languages).
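A flame graph is drawn from aggregated stack samples. A common intermediate text format, used by Brendan Gregg's original FlameGraph scripts, is "folded stacks": one line per distinct call stack, with the frames joined by semicolons and followed by the number of samples. The sketch below builds that format from a list of hypothetical samples (the function names are made up for illustration). In the resulting graph, train would be 3 of 4 samples (75%) wide, and load_batch half as wide as the whole graph.

```python
from collections import Counter


def fold_stacks(samples):
    """Collapse stack samples into folded-stack lines of the form
    'outermost;...;innermost count', one line per distinct stack."""
    folded = Counter(";".join(stack) for stack in samples)
    return [f"{stack} {count}" for stack, count in sorted(folded.items())]


# Hypothetical samples: each tuple is one snapshot of the call stack,
# outermost function first.
samples = [
    ("main", "train", "load_batch"),
    ("main", "train", "load_batch"),
    ("main", "train", "forward"),
    ("main", "evaluate"),
]

for line in fold_stacks(samples):
    print(line)
# main;evaluate 1
# main;train;forward 1
# main;train;load_batch 2
```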

You can generate a flame graph with:

py-spy record -o flame-graph-mnist.svg -- python examples/mnist/main.py --epochs 10

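Once generated, open the SVG file in a web browser. The flame graph is interactive: click a box to zoom in on that part of the call stack.

Figure: py-spy flame graph example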

2.2.2 Viewing a Live Function Call Summary

To display a top-like live view of which functions are consuming the most CPU time in your Python script, use:

py-spy top -- python examples/mnist/main.py --epochs 10

This command will launch your script and simultaneously show a real-time function usage summary, updated every second.

Figure: A top-like live view example

Legend:

  • %Own: percentage of time currently spent in the function itself
  • %Total: percentage of time currently spent in the function and its children
  • OwnTime: overall time spent in the function itself
  • TotalTime: overall time spent in the function and its children
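Conceptually, the "own" and "total" columns can be derived from stack samples: a function's own share counts the samples in which it is the innermost frame (its own code was running), and its total share counts the samples in which it appears anywhere on the stack (it, or something it called, was running). The sketch below is a simplified model under that assumption, not py-spy's actual implementation; own_and_total and the sampled function names are hypothetical.

```python
from collections import Counter


def own_and_total(samples):
    """Return {function: (own_pct, total_pct)} computed from stack samples.

    own_pct:   share of samples in which the function is the innermost frame,
               i.e. its own code was running.
    total_pct: share of samples in which the function is anywhere on the
               stack, i.e. it or something it called was running.
    """
    n = len(samples)
    own = Counter(stack[-1] for stack in samples)
    # set() so a recursive function is counted only once per sample.
    total = Counter(name for stack in samples for name in set(stack))
    return {name: (100.0 * own[name] / n, 100.0 * total[name] / n) for name in total}


# Hypothetical samples, outermost function first.
samples = [
    ("main", "train", "load_batch"),
    ("main", "train", "load_batch"),
    ("main", "train", "forward"),
    ("main", "evaluate"),
]
for name, (own_pct, total_pct) in sorted(own_and_total(samples).items()):
    print(f"{name:<10} %Own {own_pct:5.1f}  %Total {total_pct:5.1f}")
```

A function with a high %Total but a low %Own, such as train here, is mostly waiting on the functions it calls; optimization effort is usually better spent on the callees with a high %Own.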

3. line_profiler

line_profiler is a deterministic, line-by-line profiler for Python. Unlike py-spy, which samples running code, line_profiler instruments specific functions and provides detailed timing information for every single line inside those functions.

This makes it ideal for targeted performance analysis, for example when you suspect that a particular loop or function is a bottleneck.
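line_profiler does its per-line timing in compiled code, but the deterministic idea behind it can be sketched in pure Python with the standard sys.settrace hook: the interpreter reports every line executed in the traced function, so the counts are exact rather than estimated from samples. This sketch only counts hits (the equivalent of the Hits column) and does not measure time; count_line_hits and demo are hypothetical names.

```python
import sys
from collections import Counter


def count_line_hits(func, *args, **kwargs):
    """Call func(*args, **kwargs) and count how often each line of func runs.

    Lines are numbered relative to the `def` line (offset 0). Only func's own
    frames are traced; functions it calls are ignored.
    """
    code = func.__code__
    hits = Counter()

    def tracer(frame, event, arg):
        if frame.f_code is not code:
            return None  # do not trace other functions line by line
        if event == "line":
            hits[frame.f_lineno - code.co_firstlineno] += 1
        return tracer  # keep receiving events for this frame

    previous = sys.gettrace()
    sys.settrace(tracer)
    try:
        result = func(*args, **kwargs)
    finally:
        sys.settrace(previous)  # restore any tracer that was already active
    return result, hits


def demo(n):
    total = 0
    for i in range(n):
        total += i
    return total


if __name__ == "__main__":
    result, hits = count_line_hits(demo, 10)
    for offset in sorted(hits):
        print(f"def+{offset}: {hits[offset]} hits")
```

Because a callback runs for every executed line, deterministic tracing slows the traced function down noticeably, which is one reason to decorate only the functions you are investigating.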

3.1 Installing line_profiler

Within a Conda environment:

# Using mamba
mamba install line_profiler
# Using pip
pip install line_profiler

3.2 Usage

3.2.1 Step 1: Mark Functions to Profile

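Import profile from line_profiler and apply the @profile decorator to each function you want to analyze. The decorator does nothing unless profiling is enabled, for example with the LINE_PROFILE environment variable used in Step 2: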

import torch.nn.functional as F
from line_profiler import profile


# Each line of a function decorated with @profile is timed separately.
@profile
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    # Time spent fetching each batch from train_loader is attributed to this line.
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

3.2.2 Step 2: Run Your Script

Run the script with the LINE_PROFILE environment variable set to 1. This executes the script, generates a .lprof file, and prints a detailed line-by-line report to the terminal:

LINE_PROFILE=1 python examples/mnist/main.py --epochs 10
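If the profiled script is launched from another Python program (for example, a driver that runs several configurations), the same environment variable can be set for the child process only. A minimal sketch using the standard subprocess module; run_with_line_profile is a hypothetical helper name.

```python
import os
import subprocess
import sys


def run_with_line_profile(script_args, **kwargs):
    """Run `python <script_args...>` in a child process with LINE_PROFILE=1.

    The variable is set only in the child's environment; the parent's
    os.environ is not modified. Extra keyword arguments go to subprocess.run.
    """
    env = dict(os.environ, LINE_PROFILE="1")
    return subprocess.run([sys.executable, *script_args], env=env, check=True, **kwargs)
```

For example, run_with_line_profile(["examples/mnist/main.py", "--epochs", "10"]) mirrors the shell command above, using the same interpreter as the calling program.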

3.2.3 Step 3: Interpreting the Output

To view the detailed results saved in the .lprof file, run:

python -m line_profiler -rtmz profile_output.lprof

A typical line_profiler output looks like:

Timer unit: 1e-06 s

Total time: 98.2199 s
File: /proj/nsc_testing/xuan/examples/mnist/main.py
Function: train at line 37

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    37                                           @profile
    38                                           def train(args, model, device, train_loader, optimizer, epoch):
    39        10       1136.9    113.7      0.0      model.train()
    40      9390   84474146.5   8996.2     86.0      for batch_idx, (data, target) in enumerate(train_loader):
    41      9380     601675.8     64.1      0.6          data, target = data.to(device), target.to(device)
    42      9380     740105.6     78.9      0.8          optimizer.zero_grad()
    43      9380    3680331.0    392.4      3.7          output = model(data)
    44      9380     257046.2     27.4      0.3          loss = F.nll_loss(output, target)
    45      9380    4744006.5    505.8      4.8          loss.backward()
    46      9380    3406183.2    363.1      3.5          optimizer.step()
    47      9380       7312.5      0.8      0.0          if batch_idx % args.log_interval == 0:
    48      1880      15025.2      8.0      0.0              print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    49       940       6573.5      7.0      0.0                  epoch, batch_idx * len(data), len(train_loader.dataset),
    50       940     285663.0    303.9      0.3                  100. * batch_idx / len(train_loader), loss.item()))
    51       940        647.9      0.7      0.0              if args.dry_run:
    52                                                           break

 98.22 seconds - /proj/nsc_testing/xuan/examples/mnist/main.py:37 - train

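Focus on lines with a high % Time; these are your optimization targets. In this example, 86% of the time is spent on line 40, the for statement that fetches each batch from train_loader, which suggests that data loading, rather than the model computation, dominates the runtime.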
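The Per Hit and % Time columns can be recomputed from the other fields of the report: Per Hit = Time / Hits, and % Time = 100 * Time * timer unit / Total time. The sketch below checks this for a few rows copied from the output above; per_hit_and_percent is a hypothetical helper, not part of line_profiler.

```python
# Rows copied from the report above: line number -> (Hits, Time in timer units).
ROWS = {
    40: (9390, 84474146.5),  # for batch_idx, (data, target) in enumerate(train_loader):
    43: (9380, 3680331.0),   # output = model(data)
    45: (9380, 4744006.5),   # loss.backward()
}
TIMER_UNIT = 1e-6     # "Timer unit: 1e-06 s"
TOTAL_TIME = 98.2199  # "Total time: 98.2199 s"


def per_hit_and_percent(hits, time_units, total_seconds=TOTAL_TIME, timer_unit=TIMER_UNIT):
    """Return (Per Hit in timer units, % Time) for one row of the report."""
    per_hit = time_units / hits
    percent = 100.0 * time_units * timer_unit / total_seconds
    return per_hit, percent


for line, (hits, time_units) in ROWS.items():
    per_hit, percent = per_hit_and_percent(hits, time_units)
    print(f"line {line}: {per_hit:.1f} per hit, {percent:.1f}% of total")
# line 40: 8996.2 per hit, 86.0% of total
# line 43: 392.4 per hit, 3.7% of total
# line 45: 505.8 per hit, 4.8% of total
```

Read per batch, line 40 costs about 9 ms (8996.2 µs), more than the forward pass, backward pass, and optimizer step combined (392.4 + 505.8 + 363.1 ≈ 1261 µs).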

