Skip to content

Commit

Permalink
[BE] Rewrite check_binary_symbols as Python script
Browse files Browse the repository at this point in the history
This makes code a bit more readable and make verification run so much faster as we can cache the results on python side and also run matching concurrently
  • Loading branch information
malfet committed Sep 4, 2024
1 parent 1b54283 commit 098ded5
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 74 deletions.
75 changes: 1 addition & 74 deletions check_binary.sh
Original file line number Diff line number Diff line change
Expand Up @@ -123,81 +123,8 @@ if [[ "$(uname)" != 'Darwin' ]]; then

# We also check that there are [not] cxx11 symbols in libtorch
#
# To check whether it is using cxx11 ABI, check non-existence of symbol:
PRE_CXX11_SYMBOLS=(
"std::basic_string<"
"std::list"
)
# To check whether it is using pre-cxx11 ABI, check non-existence of symbol:
CXX11_SYMBOLS=(
"std::__cxx11::basic_string"
"std::__cxx11::list"
)
# NOTE: Checking the above symbols in all namespaces doesn't work, because
# devtoolset7 always produces some cxx11 symbols even if we build with old ABI,
# and CuDNN always has pre-cxx11 symbols even if we build with new ABI using gcc 5.4.
# Instead, we *only* check the above symbols in the following namespaces:
LIBTORCH_NAMESPACE_LIST=(
"c10::"
"at::"
"caffe2::"
"torch::"
)
echo "Checking that symbols in libtorch.so have the right gcc abi"
grep_symbols () {
symbols=("$@")
for namespace in "${LIBTORCH_NAMESPACE_LIST[@]}"
do
for symbol in "${symbols[@]}"
do
nm "$lib" | c++filt | grep " $namespace".*$symbol
done
done
}
check_lib_symbols_for_abi_correctness () {
lib=$1
echo "lib: " $lib
if [[ "$DESIRED_DEVTOOLSET" == *"cxx11-abi"* ]]; then
num_pre_cxx11_symbols=$(grep_symbols "${PRE_CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_pre_cxx11_symbols: " $num_pre_cxx11_symbols
if [[ "$num_pre_cxx11_symbols" -gt 0 ]]; then
echo "Found pre-cxx11 symbols but there shouldn't be. Dumping symbols"
grep_symbols "${PRE_CXX11_SYMBOLS[@]}"
exit 1
fi
num_cxx11_symbols=$(grep_symbols "${CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_cxx11_symbols: " $num_cxx11_symbols
if [[ "$num_cxx11_symbols" -lt 1000 ]]; then
echo "Didn't find enough cxx11 symbols. Aborting."
exit 1
fi
else
num_cxx11_symbols=$(grep_symbols "${CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_cxx11_symbols: " $num_cxx11_symbols
if [[ "$num_cxx11_symbols" -gt 0 ]]; then
echo "Found cxx11 symbols but there shouldn't be. Dumping symbols"
grep_symbols "${CXX11_SYMBOLS[@]}"
exit 1
fi
num_pre_cxx11_symbols=$(grep_symbols "${PRE_CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_pre_cxx11_symbols: " $num_pre_cxx11_symbols
if [[ "$num_pre_cxx11_symbols" -lt 1000 ]]; then
echo "Didn't find enough pre-cxx11 symbols. Aborting."
exit 1
fi
fi
}
# After https://github.com/pytorch/pytorch/pull/29731 most of the real
# libtorch code will live in libtorch_cpu, not libtorch, so cxx11
# symbol counting won't work on libtorch (since there's nothing in
# it.) Fortunately, libtorch_cpu.so doesn't exist prior to this PR,
# so just test if the file exists and use it if it does.
if [ -f "${install_root}/lib/libtorch_cpu.so" ]; then
libtorch="${install_root}/lib/libtorch_cpu.so"
else
libtorch="${install_root}/lib/libtorch.so"
fi
check_lib_symbols_for_abi_correctness $libtorch
python test/check_binary_symbols.py

echo "cxx11 symbols seem to be in order"
fi # if on Darwin
Expand Down
85 changes: 85 additions & 0 deletions test/check_binary_symbols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import concurrent.futures
import distutils.sysconfig
import itertools
import functools
import os
import re
from pathlib import Path

# We also check that there are [not] cxx11 symbols in libtorch
#
# To check whether it is using cxx11 ABI, check non-existence of symbol:
PRE_CXX11_SYMBOLS=(
"std::basic_string<",
"std::list",
)
# To check whether it is using pre-cxx11 ABI, check non-existence of symbol:
CXX11_SYMBOLS=(
"std::__cxx11::basic_string",
"std::__cxx11::list",
)
# NOTE: Checking the above symbols in all namespaces doesn't work, because
# devtoolset7 always produces some cxx11 symbols even if we build with old ABI,
# and CuDNN always has pre-cxx11 symbols even if we build with new ABI using gcc 5.4.
# Instead, we *only* check the above symbols in the following namespaces:
LIBTORCH_NAMESPACE_LIST=(
"c10::",
"at::",
"caffe2::",
"torch::",
)

LIBTORCH_CXX11_PATTERNS = [re.compile(f"{x}.*{y}") for (x,y) in itertools.product(LIBTORCH_NAMESPACE_LIST, CXX11_SYMBOLS)]

LIBTORCH_PRE_CXX11_PATTERNS = [re.compile(f"{x}.*{y}") for (x,y) in itertools.product(LIBTORCH_NAMESPACE_LIST, PRE_CXX11_SYMBOLS)]

@functools.lru_cache
def get_symbols(lib :str ) -> list[tuple[str, str, str]]:
from subprocess import check_output
lines = check_output(f'nm "{lib}"|c++filt', shell=True)
return [x.split(' ', 2) for x in lines.decode('latin1').split('\n')[:-1]]


def count_symbols(lib: str, patterns: list[re.Match]) -> int:
def _count_symbols(symbols: list[tuple[str, str, str]], patterns: list[str]) -> int:
rc = 0
for s_addr, s_type, s_name in symbols:
for pattern in patterns:
if pattern.match(s_name):
rc += 1
return rc
all_symbols = get_symbols(lib)
num_workers= 32
chunk_size = (len(all_symbols) + num_workers - 1 ) // num_workers
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
tasks = [executor.submit(_count_symbols, all_symbols[i * chunk_size : (i + 1) * chunk_size], patterns) for i in range(num_workers)]
return sum(x.result() for x in tasks)

def check_lib_symbols_for_abi_correctness(lib: str, pre_cxx11_abi: bool = True) -> None:
print(f"lib: {lib}")
num_cxx11_symbols = count_symbols(lib, LIBTORCH_CXX11_PATTERNS)
num_pre_cxx11_symbols = count_symbols(lib, LIBTORCH_PRE_CXX11_PATTERNS)
if pre_cxx11_abi:
if num_cxx11_symbols > 0:
raise RuntimeError("Found cxx11 symbols, but there shouldn't be any")
if num_pre_cxx11_symbols < 1000:
raise RuntimeError("Didn't find enough pre-cxx11 symbols.")
else:
if num_pre_cxx11_symbols > 0:
raise RuntimeError("Found pre-cxx11 symbols, but there shouldn't be any")
if num_cxx11_symbols < 100:
raise RuntimeError("Didn't find enought cxx11 symbols")

def main() -> None:
if os.getenv("PACKAGE_TYPE") == "libtorch":
install_root = Path(__file__).parent.parent
else:
install_root = Path(distutils.sysconfig.get_python_lib()) / "torch"
libtorch_cpu_path = install_root / "lib" / "libtorch_cpu.so"
pre_cxx11_abi = "cxx11-abi" not in os.getenv("DESIRED_DEVTOOLSET", "")
check_lib_symbols_for_abi_correctness(libtorch_cpu_path, pre_cxx11_abi)


if __name__ == "__main__":
main()

0 comments on commit 098ded5

Please sign in to comment.