unroll axis 0 in tensor core (#11155)
* unroll is 0 in tc [pr]
* flip order of upcast/reduce in tensor core
* Revert "flip order of upcast/reduce in tensor core"
This reverts commit e564e38bcd.
@@ -389,7 +389,7 @@ class Kernel:
           for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
         # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
-        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
+        for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0
         for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
         self.tensor_core = tc
         self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
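Note on the change above: the UNROLL axis is now hard-coded to 0 (with a TODO to pass the proper reduce index) instead of tc_opts.axes[2]-self.first_reduce, and one factor-2 UNROLL is applied per power of two in the reduce dim K. Below is a standalone sketch of the resulting opt sequence in plain Python, not tinygrad's Opt/OptOps; K, the axes list, and the tc.opts string are assumed example values.

import math
from collections import namedtuple

# hypothetical stand-in for tinygrad's Opt; all values below are assumptions for illustration
Opt = namedtuple("Opt", ["op", "axis", "arg"])

K, axes, opts = 8, [4, 5, 6], ("u0", "u1", "l0", "l1")
reduce_axes = [(i, 2) for i in range(int(math.log2(K)))]  # [(0, 2), (1, 2), (2, 2)]

applied = [Opt("UNROLL", 0, amt) for _, amt in reduce_axes]  # new behavior: always axis 0
applied += [Opt({"u": "UPCAST", "l": "LOCAL"}[o[0]], axes[int(o[1])], 2) for o in opts]
print(applied)  # three UNROLLs of 2, then UPCAST on axes 4 and 5, LOCAL on axes 4 and 5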
@@ -464,7 +464,7 @@ class Kernel:
     def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
       upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
       return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
-    def get_tc_swizzle_st(shape, local_perm, reduce_perm, upcast_perm):
+    def get_tc_swizzle_st(shape, local_perm, upcast_perm, reduce_perm):
       ru_perm = reduce_perm + upcast_perm
       offset = (tcd - (wd + len(local_perm)))
       permaxis = list(range(wd)) \
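For context on get_upcast_axes, elements_per_thread[buf] is split into int(math.log2(...)) factor-2 upcast axes, indexed backwards from the end of the (tc_reduce + tc_upcast) block. A minimal standalone sketch of that index math; tcd, elements_per_thread, and the axis counts are assumed example values:

import math

# sketch of get_upcast_axes' index arithmetic; inputs are assumptions for illustration
def upcast_axes(elements_per_thread, tcd, n_reduce, n_upcast):
  n = int(math.log2(elements_per_thread))  # one factor-2 axis per power of two
  return tuple((tcd + n_reduce + n_upcast - (i + 1), 2) for i in range(n))

print(upcast_axes(8, tcd=4, n_reduce=4, n_upcast=2))  # ((9, 2), (8, 2), (7, 2))

The only change in this hunk is the get_tc_swizzle_st signature, which now takes upcast_perm before reduce_perm; this appears to line up with the reordered swizzle tuples returned in the TensorCore hunk below.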
@@ -18,7 +18,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
     local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
     fwd_st = [f"l{i}" for i in range(local_axes)] + [f"r{i}" for i in range(reduce_axes)] + [f"u{i}" for i in range(upcast_axes)]
     st = {s:i for i,s in enumerate(fwd_st)}
-    return tuple((tuple([st[x] for x in s[0]]), tuple([st[x] for x in s[2]]), tuple([st[x] for x in s[1]])) for s in self.swizzle)
+    return tuple((tuple([st[x] for x in s[0]]), tuple([st[x] for x in s[1]]), tuple([st[x] for x in s[2]])) for s in self.swizzle)
   def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
   def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
   def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
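To see what the swizzle remapping produces, here is a standalone sketch of the name-to-index table built from fwd_st; the axis counts and the sample swizzle entry are assumed for illustration only:

import math

# assumed example: 2 local axes, K=8 -> 3 factor-2 reduce axes, 2 upcast axes
local_axes, reduce_axes, upcast_axes = 2, int(math.log2(8)), 2
fwd_st = [f"l{i}" for i in range(local_axes)] + [f"r{i}" for i in range(reduce_axes)] + [f"u{i}" for i in range(upcast_axes)]
st = {s: i for i, s in enumerate(fwd_st)}
print(st)  # {'l0': 0, 'l1': 1, 'r0': 2, 'r1': 3, 'r2': 4, 'u0': 5, 'u1': 6}

# made-up swizzle entry of name groups; the new return keeps them in stored order (s[0], s[1], s[2])
s = (("l0", "l1"), ("u0", "u1"), ("r0", "r1", "r2"))
print(tuple(tuple(st[x] for x in grp) for grp in s))  # ((0, 1), (5, 6), (2, 3, 4))

With the change above, the three index groups come out in the same order as they are stored in self.swizzle, rather than with the second and third groups swapped.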