Files
ROCm/python/test/tools/compare_files.py
Justin Lebar df08301e76 Reformat Python code with yapf. (#2589)
I've added an option to yapf to do what we want for long lines, see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
2023-11-02 20:44:17 -07:00

251 lines
8.2 KiB
Python

import argparse
import difflib
import glob
import os
import sys
from typing import Dict, List, Optional, Tuple
import yaml
class ComparisonResult:
    """
    Outcome of comparing the cache entries that share one kernel name

    Attributes:
        name: kernel name this result belongs to
        numComparisons: number of comparisons that were performed for this kernel
        diffs: one entry per mismatch, each entry being the list of diff lines
        errors: error messages collected during the comparison
    """

    def __init__(
        self,
        name: str,
        numComparisons: int,
        diffs: Optional[List[List[str]]] = None,
        errors: Optional[List[str]] = None,
    ):
        self.name = name
        self.numComparisons = numComparisons
        # None (rather than a list literal) is the default so instances never share one mutable list
        self.diffs = [] if diffs is None else diffs
        self.errors = [] if errors is None else errors

    def isSuccess(self) -> bool:
        """
        Returns True if neither diffs nor errors were recorded
        """
        return not self.diffs and not self.errors

    def __str__(self) -> str:
        return f"name={self.name}, numComparisons={self.numComparisons}, success={self.isSuccess()}"
def listFilesWithExtension(path: str, extension: str) -> List[str]:
    """
    Returns a list of files in the given path with the given extension
    The files are returned joined with the given path, in the order produced by glob
    """
    # Escape the directory part so that glob metacharacters ("*", "?", "[") appearing in the
    # path are taken literally; only the trailing "*.<extension>" is meant to act as a pattern
    pattern = os.path.join(glob.escape(path), f'*.{extension}')
    return glob.glob(pattern)
def getFileWithExtension(path: str, ext: str) -> Optional[str]:
    """
    Looks up the one file in the given path that has the given extension

    Returns None when the path holds no file with that extension at all
    Files with "__grp__" in their path are ignored; if that does not leave exactly one file,
    a message is printed and the process exits with status 2
    """
    matches = listFilesWithExtension(path, ext)
    if not matches:
        return None
    # ignore the entries that carry "__grp__" in their path
    candidates = [match for match in matches if "__grp__" not in match]
    if len(candidates) == 1:
        return candidates[0]
    print(f"Found {len(candidates)} files in {path} with extension {ext}!")
    sys.exit(2)
def loadYamlFile(filePath: str) -> List[Dict[str, str]]:
    """
    Opens the given file and returns whatever yaml.safe_load parses from it
    """
    with open(filePath, 'r') as stream:
        return yaml.safe_load(stream)
def compareFiles(file1: str, file2: str) -> bool:
    """
    Byte-wise comparison of two files
    Both files are read completely in binary mode; the result is True only if the contents are identical
    """
    with open(file1, 'rb') as lhs, open(file2, 'rb') as rhs:
        return lhs.read() == rhs.read()
def diffFiles(file1: str, file2: str) -> List[str]:
    """
    Returns the unified diff from file1 to file2 as a list of lines
    Both files are read in text mode and their paths are used as the from/to labels of the diff
    The list is empty if both files contain the same lines
    """
    linesPerFile = []
    for filePath in (file1, file2):
        with open(filePath, 'r') as handle:
            linesPerFile.append(handle.readlines())
    fromLines, toLines = linesPerFile
    return list(difflib.unified_diff(fromLines, toLines, fromfile=file1, tofile=file2))
def getFileVec(path: str) -> List[Tuple[str, str]]:
    """
    Collects the (extension, file) pairs found in the given path (note: the path includes the hash)
    The extensions are looked up in the fixed order json, ttir, ttgir and the result keeps that order;
    an extension for which no file exists is left out
    """
    lookups = ((ext, getFileWithExtension(path, ext)) for ext in ("json", "ttir", "ttgir"))
    return [(ext, found) for ext, found in lookups if found is not None]
def getNameToHashesDict(path: str) -> Dict[str, List[str]]:
    """
    Groups the hash directories found in the given path by the kernel name stored in their json file
    Every entry of the path has to be a directory, otherwise a message is printed and the process
    exits with status 2
    Note: hashes whose file vector has fewer than two files or does not start with a json file are ignored
    """
    nameToHashes = {}
    for entry in os.listdir(path):
        hashDir = os.path.join(path, entry)
        if not os.path.isdir(hashDir):
            print(f"Path {hashDir} is not a directory!")
            sys.exit(2)
        fileVec = getFileVec(hashDir)
        usable = len(fileVec) >= 2 and fileVec[0][0] == "json"
        if not usable:
            continue
        _, jsonFile = fileVec[0]
        # the json file is parsed with the yaml loader
        with open(jsonFile, 'r') as stream:
            metadata = yaml.safe_load(stream)
        kernelName = metadata["name"]
        if kernelName not in nameToHashes:
            nameToHashes[kernelName] = []
        nameToHashes[kernelName].append(entry)
    return nameToHashes
def doFilesMatch(path1: str, path2: str) -> bool:
    """
    Decides whether the files in the two given paths match
    The file vectors of both paths (json, ttir, ttgir, in that order, missing ones left out) must have
    the same length and are walked pairwise; a pair with different extensions or different content
    means no match
    The walk stops successfully as soon as a ttir or ttgir pair compared equal, so a ttgir pair is only
    compared when there is no ttir pair
    """
    filesVec1 = getFileVec(path1)
    filesVec2 = getFileVec(path2)
    if len(filesVec1) != len(filesVec2):
        return False
    for (ext1, file1), (ext2, file2) in zip(filesVec1, filesVec2):
        if ext1 != ext2 or not compareFiles(file1, file2):
            return False
        # one matching ttir or ttgir pair is enough, nothing after it needs to be compared
        if ext1 in ("ttir", "ttgir"):
            return True
    return True
def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]],
                         args) -> ComparisonResult:
    """
    Compares the ptx files of all hash pairs (one hash from each path) registered under the given kernel name
    A pair is only compared if doFilesMatch accepts the two hash directories
    args must provide the attributes path1 and path2, the directories the hashes live in

    Returns a ComparisonResult with:
    - numComparisons: the number of hash pairs accepted by doFilesMatch
    - diffs: one unified diff (from the path2 file to the path1 file) per pair whose ptx files differ
    - errors: a message per accepted pair that lacks a ptx file, and a message if no pair was accepted
    """
    hashes1 = nameToHashes1.get(name, [])
    hashes2 = nameToHashes2.get(name, [])
    diffs = []
    errors = []
    numComparisons = 0
    for hash1 in hashes1:
        path1 = os.path.join(args.path1, hash1)
        for hash2 in hashes2:
            path2 = os.path.join(args.path2, hash2)
            # check whether both paths have:
            # 1. json files that match
            # 2. ttir files that match (if they exist), otherwise ttgir files that match (if they exist)
            # if any of these constraints is not met, then we can skip this pair of hashes since they are not a match
            if not doFilesMatch(path1, path2):
                continue
            numComparisons += 1
            ptxFiles1 = listFilesWithExtension(path1, "ptx")
            ptxFiles2 = listFilesWithExtension(path2, "ptx")
            # report a missing ptx file as an error of this kernel instead of failing with an IndexError
            missingIn = [p for p, found in ((path1, ptxFiles1), (path2, ptxFiles2)) if not found]
            if missingIn:
                errors.append(f"No ptx file found in {', '.join(missingIn)} for {name}")
                continue
            # the recorded diff goes from the path2 file to the path1 file; it is empty exactly when
            # the two files have the same lines, so a single diff serves as both the check and the record
            diff = diffFiles(ptxFiles2[0], ptxFiles1[0])
            if diff:
                diffs.append(diff)
    if numComparisons == 0:
        errors.append(f"Did not find any matching files for {name}")
    return ComparisonResult(name=name, numComparisons=numComparisons, diffs=diffs, errors=errors)
def dumpResults(results: List["ComparisonResult"], fileName: str):
    """
    Dumps the results to the given file (the file is overwritten)
    For every result the following is written: its string form, a "Diffs:" section with the lines of
    all its diffs, an "Errors:" section with one error message per line, and a blank line
    """
    with open(fileName, 'w') as file:
        for result in results:
            file.write(str(result) + "\n")
            file.write("Diffs:\n")
            for diff in result.diffs:
                # diff lines are written as they are, no line terminator is added
                file.writelines(diff)
            file.write("Errors:\n")
            # separate the messages with newlines so that several errors do not run together on one line
            file.write("\n".join(result.errors))
            file.write("\n\n")
def main(args) -> None:
    """
    Compares the kernels found in the two directories args.path1 and args.path2
    The hashes in both directories are grouped by kernel name and, for every kernel name found in
    either directory, all hashes of path1 are compared with all hashes of path2
    Every result is printed and all results are dumped to "kernels_reference_check.txt"

    This function does not return, it always ends the process:
    - exit status 0 if every kernel compared successfully
    - exit status 1 if any kernel has diffs or errors
    - exit status 2 if path1 and path2 are the same
    """
    if args.path1 == args.path2:
        print("Cannot compare files in the same directory!")
        sys.exit(2)
    # Get kernel name to hashes dict, these hashes would have the same kernel name
    nameToHashes1 = getNameToHashesDict(args.path1)
    nameToHashes2 = getNameToHashesDict(args.path2)
    # Get all kernels that need to be checked; sorted so that the printed output and the
    # dumped report come out in the same order on every run
    kernelNames = sorted(set(nameToHashes1) | set(nameToHashes2), key=str)
    results = []
    # iterate over the kernels that need to be checked
    for name in kernelNames:
        # Compare all hashes on path 1 with all hashes on path 2
        result = compareMatchingFiles(name, nameToHashes1, nameToHashes2, args)
        print(result)
        # keep every result (successful or not) for the report
        results.append(result)
    # Dump results
    dumpResults(results, "kernels_reference_check.txt")
    success = all(result.isSuccess() for result in results)
    if not success:
        print("Failed!")
        sys.exit(1)
    print("Passed!")
    sys.exit(0)
if __name__ == "__main__":
    # Command line entry point: both cache directories are mandatory
    parser = argparse.ArgumentParser()
    parser.add_argument("--path1", type=str, default=None, required=True, help="Path to first cache directory")
    parser.add_argument("--path2", type=str, default=None, required=True, help="Path to second cache directory")
    args = parser.parse_args()
    main(args)