debug print shapetrackers

This commit is contained in:
George Hotz
2023-02-28 08:11:40 -08:00
parent cfa5a12f13
commit a8bbcccc16
2 changed files with 6 additions and 6 deletions

View File

@@ -108,8 +108,11 @@ class ASTKernel:
return cache[x]
print_ast(self.input_ast, "ast")
def printbufs(self, prefix=""):
def printbufs(self, prefix="", print_shapetrackers=False):
print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
if print_shapetrackers:
for st in self.sts:
print(st)
for i in range(len(self.sts)):
print(prefix, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), type(self.bufs[i]._buf) if i < len(self.bufs) else "FAKE")

View File

@@ -236,7 +236,7 @@ class CLASTKernel(ASTKernel):
def codegen(self) -> Callable:
self.process()
self.upcast_in_mid_reduce = False
if DEBUG >= 3: self.printbufs("old:")
if DEBUG >= 3: self.printbufs("old:", DEBUG>=4)
if KOPT == -1 or IMAGE == 2: self.hand_coded_optimizations()
# add a local buffer for multistage reduce
@@ -247,10 +247,7 @@ class CLASTKernel(ASTKernel):
self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
if DEBUG >= 3:
print("output shape", self.output_shape)
if DEBUG >= 4:
for b in self.bufs:
print(b.st)
self.printbufs("new:")
self.printbufs("new:", DEBUG>=4)
self.bufs_to_delete : Set[int] = set()
self.loaded_keys : Dict[Tuple[int,int], Token] = {}