LLVM 15 Debugging

LLVM 15 and LLVM 20 produce different results with fastmath=True when get_inverse_doppler_factor is inlined into line_scatter_event. The individual values differ by only 1-2 ULP (units in the last place), but those differences then cascade through later calculations.

[1]:
import os, struct, re
# Request a single Numba thread. This is set here, ahead of the tardis import
# in a later cell.
# NOTE(review): presumably the variable must be set before numba is first
# imported in order to take effect — confirm.
os.environ["NUMBA_NUM_THREADS"] = "1"

import numpy as np
import llvmlite.binding as binding
# Report the LLVM version that llvmlite is bound to, so the IR shown further
# down can be attributed to a specific LLVM release.
print(f"LLVM {binding.llvm_version_info}")

def f2h(f):
    """Return the big-endian IEEE-754 double encoding of ``f`` as a hex string.

    ``f`` is converted with ``float()`` first, so anything float-convertible
    is accepted. The result is always 16 hexadecimal digits (8 bytes).
    """
    as_double = float(f)
    raw_bytes = struct.pack('!d', as_double)
    return raw_bytes.hex()

def ulp_diff(a, b):
    """Return the distance between two doubles in units in the last place.

    Each value is reinterpreted as a 64-bit integer. IEEE-754 doubles are
    stored as sign + magnitude, so the raw signed integer is only ordered
    correctly for non-negative values. Negative values are therefore mapped
    to ``-magnitude`` before subtracting, which makes the integer line
    monotonic across zero (``+0.0`` and ``-0.0`` both map to 0).

    Parameters
    ----------
    a, b : float
        Values to compare; anything ``struct.pack('!d', ...)`` accepts.

    Returns
    -------
    int
        Number of representable doubles separating ``a`` and ``b``
        (0 when they are equal, including ``0.0`` vs ``-0.0``).
    """
    def ordered_bits(x):
        bits = struct.unpack('!q', struct.pack('!d', x))[0]
        if bits < 0:
            # Sign bit set: strip it and negate the magnitude so that larger
            # negative magnitudes sort further below zero.
            return -(bits & 0x7FFFFFFFFFFFFFFF)
        return bits

    return abs(ordered_bits(a) - ordered_bits(b))
LLVM (15, 0, 7)
[2]:
# Run a deliberately small simulation; its only purpose here is to trigger JIT.
import tardis
from tardis.io.configuration.config_reader import Configuration
from tardis.simulation import Simulation

# Locate the example configuration relative to the installed tardis package.
tardis_root = os.path.dirname(tardis.__file__)
config_path = os.path.join(
    tardis_root,
    "io", "configuration", "tests", "data", "tardis_configv1_verysimple.yml",
)
config = Configuration.from_yaml(config_path)

# Override the interaction type and shrink the packet counts.
config["plasma"]["line_interaction_type"] = "downbranch"
for option, value in (
    ("no_of_packets", 100),
    ("last_no_of_packets", 100),
    ("no_of_virtual_packets", 0),
):
    config["montecarlo"][option] = value

# Atomic data: the TARDIS_ATOM_DATA environment variable wins, otherwise fall
# back to a file under the user's Downloads directory.
default_atom_data = os.path.expanduser(
    "~/Downloads/tardis-data/kurucz_cd23_chianti_H_He.h5"
)
config.atom_data = os.environ.get("TARDIS_ATOM_DATA", default_atom_data)

sim = Simulation.from_config(config)
sim.iterate(no_of_packets=100, no_of_virtual_packets=0)
print("JIT done")
Initializing tabulator and plotly panel extensions for widgets to work
Number of density points larger than number of shells. Assuming inner point irrelevant
model_isotope_time_0 is not set in the configuration. Isotopic mass fractions will not be decayed and is assumed to be correct for the time_explosion. THIS IS NOT RECOMMENDED!
/home/aryaatharva18/tardis-main/tardis/tardis/transport/montecarlo/modes/classic/montecarlo_transport.py:161: NumbaTypeSafetyWarning:

unsafe cast from uint64 to int64. Precision may be lost.

JIT done

IR for line_scatter_event

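After inlining, LLVM 15 keeps the reciprocal as its own instruction (fdiv 1.0, x) and then multiplies by it (fmul), whereas LLVM 20 folds the pair into a single division (fdiv y, x). The first form rounds twice and the second rounds once, so the results can differ in the last bits. Look for the fdiv instructions whose first operand is 1.000000e+00 in the output below.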

[3]:
from tardis.transport.montecarlo.interaction_event_callers import line_scatter_event

# Fetch the LLVM IR for the first compiled signature of line_scatter_event.
sig = line_scatter_event.signatures[0]
ir = line_scatter_event.inspect_llvm(sig)

# Individual fast-math flags; "fast" is the shorthand for all of them. A partial
# set such as "nnan ninf" does not have to start with "reassoc", so every flag
# is checked when classifying an instruction as "reduced-flags".
PARTIAL_FMF = ("reassoc", "nnan", "ninf", "nsz", "arcp", "contract", "afn")

# Groups: (result name, opcode, remainder of the line).
#  * The name is matched lazily so it carries no trailing whitespace.
#  * A call is kept only when the callee name contains sqrt/fma/fmuladd; it may
#    carry a tail-call marker, e.g. "tail call fast double @llvm.sqrt.f64(...)".
#    The lookahead leaves the flags and callee in the third group so the
#    flag classification below also works for calls.
#  * [ \t] instead of \s keeps each match on a single IR line.
arith_re = re.compile(
    r'^[ \t]+(%.+?)[ \t]*=[ \t]*'
    r'(fmul|fadd|fsub|fdiv|fneg'
    r'|(?:(?:tail|musttail|notail)[ \t]+)?call'
    r'(?=[^@\n]*@[^\s(]*(?:sqrt|fma|fmuladd)))'
    r'[ \t]+(.+)$',
    re.MULTILINE,
)
ops = arith_re.findall(ir)


def _leading_token(rest):
    """Return the first whitespace-separated token of ``rest`` ('' if none)."""
    tokens = rest.split(None, 1)
    return tokens[0] if tokens else ""


# An instruction's fast-math flags, when present, directly follow the opcode.
fast_n = sum(1 for _, _, r in ops if _leading_token(r) == 'fast')
reassoc_n = sum(1 for _, _, r in ops if _leading_token(r) in PARTIAL_FMF)

print(f"FP ops: {len(ops)} total, {fast_n} fast, {reassoc_n} reduced-flags")
print()
print("fdiv instructions:")
for name, op, rest in ops:
    if op == 'fdiv':
        print(f"  {name} = {op} {rest}")
FP ops: 52 total, 52 fast, 0 reduced-flags

fdiv instructions:
  %.30.i  = fdiv fast double %1, %arg.time_explosion
  %.15.i.i  = fdiv fast double %.7.i.i, %sqrt.i.i
  %.22.i.i  = fdiv fast double %.126, %.9.i.i
  %.29.i.i  = fdiv fast double %.23.i.i, %.25.i.i
  %.30.i2  = fdiv fast double %23, %arg.time_explosion
  %.11.i.i  = fdiv fast double 1.000000e+00, %.7.i.i5
  %.15.i.i9  = fdiv fast double %.7.i2.i, %sqrt.i.i8
  %.30.i.i  = fdiv fast double %2, %arg.time_explosion
  %.11.i.i.i  = fdiv fast double 1.000000e+00, %.7.i.i.i
  %.15.i.i.i  = fdiv fast double %.7.i2.i.i, %sqrt.i.i.i
  %.22.i.i  = fdiv fast double %.54.i, %.9.i.i
  %.29.i.i  = fdiv fast double %.23.i.i, %.25.i.i
[4]:
# For every fdiv assignment, print it with two lines of IR on either side.
lines = ir.split('\n')
fdiv_rows = [
    row for row, text in enumerate(lines)
    if 'fdiv' in text and '=' in text
]
for i in fdiv_rows:
    # Clamp the window so it never runs off either end of the listing.
    start = max(0, i - 2)
    end = min(len(lines), i + 3)
    print(f"--- IR line {i+1} ---")
    for j in range(start, end):
        # Mark the fdiv line itself; pad the others to the same width.
        m = "   " if j != i else ">>>"
        print(f"{m} {j+1:5d}: {lines[j]}")
    print()
--- IR line 103 ---
      101:   %.55 = zext i1 %0 to i8
      102:   %1 = fmul fast double %.39, 0x3DC2567F4ED09FE8
>>>   103:   %.30.i = fdiv fast double %1, %arg.time_explosion
      104:   %.6.i.i = fmul fast double %.30.i, %.48
      105:   %.7.i.i = fsub fast double 1.000000e+00, %.6.i.i

--- IR line 112 ---
      110:   %.10.i.i = fsub fast double 1.000000e+00, %.8.i.i
      111:   %sqrt.i.i = tail call fast double @llvm.sqrt.f64(double %.10.i.i)
>>>   112:   %.15.i.i = fdiv fast double %.7.i.i, %sqrt.i.i
      113:   br label %B0.endif
      114:

--- IR line 143 ---
      141: B182.endif.i:                                     ; preds = %B36.endif.i
      142:   %.9.i.i = fmul fast double %arg.time_explosion, 0x421BEB9BF3A00000
>>>   143:   %.22.i.i = fdiv fast double %.126, %.9.i.i
      144:   %.23.i.i = fadd fast double %.31.i, %.22.i.i
      145:   %.24.i.i = fmul fast double %.31.i, %.22.i.i

--- IR line 147 ---
      145:   %.24.i.i = fmul fast double %.31.i, %.22.i.i
      146:   %.25.i.i = fadd fast double %.24.i.i, 1.000000e+00
>>>   147:   %.29.i.i = fdiv fast double %.23.i.i, %.25.i.i
      148:   %5 = bitcast { double, double, double, double, i64, i64, i64, i64, i64 }* %arg.r_packet.1 to i8*
      149:   %sunkaddr = getelementptr inbounds i8, i8* %5, i64 8

--- IR line 262 ---
      260:   %.126 = load double, double* %14, align 8
      261:   %23 = fmul fast double %.126, 0x3DC2567F4ED09FE8
>>>   262:   %.30.i2 = fdiv fast double %23, %arg.time_explosion
      263:   %.6.i.i3 = fmul fast double %.31.i, %.30.i2
      264:   br i1 %15, label %B0.endif.endif.endif, label %B0.endif.endif.endif.thread

--- IR line 268 ---
      266: B0.endif.endif.endif:                             ; preds = %B0.endif.endif
      267:   %.7.i.i5 = fsub fast double 1.000000e+00, %.6.i.i3
>>>   268:   %.11.i.i = fdiv fast double 1.000000e+00, %.7.i.i5
      269:   %.173 = getelementptr inbounds { double, double, double, double, i64, i64, i64, i64, i64 }, { double, double, double, double, i64, i64, i64, i64, i64 }* %arg.r_packet.1, i64 0, i32 3
      270:   %.174 = load double, double* %.173, align 8

--- IR line 282 ---
      280:   %.10.i.i7 = fsub fast double 1.000000e+00, %.8.i.i6
      281:   %sqrt.i.i8 = tail call fast double @llvm.sqrt.f64(double %.10.i.i7)
>>>   282:   %.15.i.i9 = fdiv fast double %.7.i2.i, %sqrt.i.i8
      283:   %.17335 = getelementptr inbounds { double, double, double, double, i64, i64, i64, i64, i64 }, { double, double, double, double, i64, i64, i64, i64, i64 }* %arg.r_packet.1, i64 0, i32 3
      284:   %.17436 = load double, double* %.17335, align 8

--- IR line 714 ---
      712:   %.63.i = load double, double* %.62.i, align 8, !noalias !30
      713:   %2 = fmul fast double %.54.i, 0x3DC2567F4ED09FE8
>>>   714:   %.30.i.i = fdiv fast double %2, %arg.time_explosion
      715:   %.6.i.i.i = fmul fast double %.30.i.i, %.63.i
      716:   br i1 %1, label %B60.endif.i.i, label %B82.endif.i.i

--- IR line 720 ---
      718: B60.endif.i.i:                                    ; preds = %B528
      719:   %.7.i.i.i = fsub fast double 1.000000e+00, %.6.i.i.i
>>>   720:   %.11.i.i.i = fdiv fast double 1.000000e+00, %.7.i.i.i
      721:   br label %B36.endif.i
      722:

--- IR line 728 ---
      726:   %.10.i.i.i = fsub fast double 1.000000e+00, %.8.i.i.i
      727:   %sqrt.i.i.i = tail call fast double @llvm.sqrt.f64(double %.10.i.i.i)
>>>   728:   %.15.i.i.i = fdiv fast double %.7.i2.i.i, %sqrt.i.i.i
      729:   br label %B36.endif.i
      730:

--- IR line 758 ---
      756: B182.endif.i:                                     ; preds = %B36.endif.i
      757:   %.9.i.i = fmul fast double %arg.time_explosion, 0x421BEB9BF3A00000
>>>   758:   %.22.i.i = fdiv fast double %.54.i, %.9.i.i
      759:   %.23.i.i = fadd fast double %.22.i.i, %.63.i
      760:   %.24.i.i = fmul fast double %.22.i.i, %.63.i

--- IR line 762 ---
      760:   %.24.i.i = fmul fast double %.22.i.i, %.63.i
      761:   %.25.i.i = fadd fast double %.24.i.i, 1.000000e+00
>>>   762:   %.29.i.i = fdiv fast double %.23.i.i, %.25.i.i
      763:   %6 = bitcast { double, double, double, double, i64, i64, i64, i64, i64 }* %arg.r_packet.1 to i8*
      764:   %sunkaddr = getelementptr inbounds i8, i8* %6, i64 8