mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix broken contiguous
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -161,7 +161,7 @@ jobs:
|
||||
run: GPU=1 IMAGE=2 python3 test/test_ops.py
|
||||
- name: Test openpilot model
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=196 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
|
||||
ALLOWED_KERNEL_COUNT=200 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
|
||||
UNSAFE_FLOAT4=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
|
||||
|
||||
# disabled, this test is flaky
|
||||
|
||||
@@ -46,7 +46,9 @@ def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : Optional[bool] = None)
|
||||
if not DEBUG and not show_graph: return
|
||||
op : List[Op] = [x.op for x in get_lazyops(ast)]
|
||||
inp : List[DeviceBuffer] = get_buffers(ast)
|
||||
if len(inp) == 1 and inp[0] == ret: return # don't log self loops
|
||||
if len(inp) == 1 and inp[0] == ret:
|
||||
if nm(ret) in G.nodes: G.nodes[nm(ret)]['style'] += ', bold'
|
||||
return # don't log self loops
|
||||
oporder = [LoadOps, FusedOps, ProcessingOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
|
||||
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
|
||||
cnts[optype] += 1
|
||||
|
||||
@@ -83,8 +83,8 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
|
||||
|
||||
def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple(get_weakop(x) if isinstance(x, LazyOp) else weakref.ref(x) for x in op.src), op.arg)
|
||||
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.op.src[0]) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
|
||||
def get_movementroot(root:LazyBuffer) -> LazyBuffer: return get_movementroot(root.op.src[0]) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and root.op.src[0].st.contiguous)) else root
|
||||
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot(x) if x.optype == MovementOps and x.st.contiguous else x
|
||||
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(root.op.src[0], allow_contiguous) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
|
||||
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(x.op.src[0]) if x.realized is None and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
|
||||
|
||||
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
|
||||
if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
|
||||
@@ -274,7 +274,7 @@ class LazyBuffer:
|
||||
# MovementOps aren't stacked any more, they each have one parent, find the root
|
||||
root = get_movementroot(self)
|
||||
if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
|
||||
return root.movement_op(MovementOps.RESHAPE, ret.st.shape) if ret.st.shape != root.shape else root
|
||||
return root.movement_op(MovementOps.RESHAPE, ret.st.shape)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user