mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-13 08:05:10 -05:00
use is to compare with enum (#3993)
* use is to compare with enum currently it's mixed between `==` and `is`, moved all to `is` * more
This commit is contained in:
@@ -48,17 +48,17 @@ class DiskRunner(JITRunner):
|
||||
skip_allocation = True
|
||||
def __init__(self, ast:LazyOp):
|
||||
# two ASTs are allowed here.
|
||||
assert ast.op == BufferOps.STORE, "output of AST must be store"
|
||||
assert ast.op is BufferOps.STORE, "output of AST must be store"
|
||||
assert ast.arg.st.contiguous, "shapetracker must be contiguous"
|
||||
# TODO: there shouldn't actually be casts here, bitcasts should fold into the load
|
||||
if ast.src[0].op == UnaryOps.CAST:
|
||||
if ast.src[0].op is UnaryOps.CAST:
|
||||
top_src = ast.src[0].src[0]
|
||||
assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts"
|
||||
self.new_dtype = ast.src[0].arg[0]
|
||||
else:
|
||||
top_src = ast.src[0]
|
||||
self.new_dtype = top_src.arg.dtype
|
||||
assert top_src.op == BufferOps.LOAD, "top of AST must be load"
|
||||
assert top_src.op is BufferOps.LOAD, "top of AST must be load"
|
||||
assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view"
|
||||
view = top_src.arg.st.views[0]
|
||||
assert view.mask is None, "view cannot have a mask"
|
||||
|
||||
Reference in New Issue
Block a user