mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
fix ptx linearizer bug [pr] (#9926)
* fix ptx bug * align 16 * revert align because it breaks pr * smallest diff that fixes ptx bug
This commit is contained in:
@@ -3,7 +3,7 @@ import collections, heapq
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
|
||||
from tinygrad.spec import type_verify
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, flatten, partition
|
||||
|
||||
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}
|
||||
@@ -192,10 +192,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
||||
if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
|
||||
# don't flow (fully) through assign and store
|
||||
elif s.op is Ops.STORE:
|
||||
# ugh, deal with non-reduce locals. probably wrong
|
||||
if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
|
||||
idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
|
||||
this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
|
||||
idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
|
||||
this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
|
||||
elif s.op is Ops.ASSIGN:
|
||||
# flow though assign, but remove the ranges used in the assign
|
||||
assert s.src[0].op is Ops.DEFINE_ACC
|
||||
|
||||
Reference in New Issue
Block a user