borgwang authored on 2025-06-19 21:13:31 +08:00, committed by GitHub
parent ac891b78f8, commit 06ea74bf2c
6 changed files with 6 additions and 6 deletions


@@ -153,7 +153,7 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
   assert False
 def fast_idiv(ctx: Renderer|None, x: UOp, d: int) -> UOp|None:
-  # idiv is truncated division, but arithmatic shift is floored division, so can only do non-negative numbers!
+  # idiv is truncated division, but arithmetic shift is floored division, so can only do non-negative numbers!
   if x.vmin<0: return None
   sign = 1 if d > 0 else -1
   m,s = magicgu(vmax := min(x.vmax, dtypes.max(x.dtype)), abs(d))
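The comment being fixed here is worth unpacking. A standalone sketch of the idea in plain Python (this is not tinygrad's magicgu, which may pick its constants differently):

# truncated division (idiv) rounds toward zero; an arithmetic right shift
# rounds toward -inf, so the two only agree for non-negative numerators.
assert int(-7 / 2) == -3 and (-7 >> 1) == -4
assert int(7 / 2) == 3 and (7 >> 1) == 3

# magic-number division: find m, s with (x*m) >> s == x // d for all 0 <= x <= vmax.
# with m = ceil(2**s / d), the error e = m*d - 2**s only needs to satisfy e*vmax < 2**s.
def magic(d: int, vmax: int) -> tuple[int, int]:
    s = 0
    while True:
        m = -(-(1 << s) // d)  # ceil(2**s / d)
        if (m * d - (1 << s)) * vmax < (1 << s): return m, s
        s += 1

m, s = magic(7, 10_000)
assert all((x * m) >> s == x // 7 for x in range(10_001))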


@@ -134,7 +134,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
     ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
-  # REDUCE supports both "horizonal" reduction and range reduction. the horizonal elements are taken in the nearest group
+  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
   return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op)
 def lower_load_store(ctx: IndexContext, x: UOp, buf: UOp):
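An illustration of the two reduction styles named in the corrected comment, with plain Python lists standing in for vectorized codegen (not tinygrad's actual lowering):

data = list(range(16))
acc = [0] * 4
for i in range(0, 16, 4):                            # range reduction: accumulate across loop iterations
    acc = [a + d for a, d in zip(acc, data[i:i+4])]  # the accumulator is a 4-wide "vector"
total = sum(acc)                                     # horizontal reduction: collapse the 4 lanes at the end
assert total == sum(data)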
def lower_load_store(ctx: IndexContext, x: UOp, buf: UOp):


@@ -47,7 +47,7 @@ class Optimizer:
f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
- help: Consider setting Tensor.training=True before calling Optimizer.step().""")
if self.fused:
# optimizer fusion just concatentates all the buffers, runs the _step, then splits them back up
# optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
[Tensor.cat(*[unwrap(t.grad).flatten() for t in self.params], dim=0)])
updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
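The fuse/step/split pattern the fixed comment describes, as a minimal sketch with plain lists (a dummy update stands in for _step, and pos mirrors the role of pos_params):

params = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
pos = [0]
for p in params: pos.append(pos[-1] + len(p))                 # split points, one per param
flat = [v for p in params for v in p]                         # concatenate all the buffers
flat = [v - 0.1 for v in flat]                                # one fused "step" over the big buffer
params = [flat[pos[i]:pos[i+1]] for i in range(len(params))]  # split them back up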


@@ -26,7 +26,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
   elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
-  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
   swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
   def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
   def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
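For readers skimming the diff, the opts encoding fixed above is a small string mini-language; a sketch of how such a tuple decodes (the values here are hypothetical, not a real tensor core config):

opts = ("u0", "u1", "l0")                              # e.g. upcast dim 0, upcast dim 1, localize dim 0
for opt in opts:
    kind, dim = opt[0], int(opt[1:])
    action = {"u": "upcast", "l": "localize"}[kind]
    print(f"{action} dim {dim}")
upcast_axes = [opt for opt in opts if opt[0] == "u"]   # the same filter get_upcast_axes uses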


@@ -216,7 +216,7 @@ class ClangRenderer(CStyleLanguage):
     '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
   ]
   # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
-  # to just jump at the start of a shellcode whithout having to deal with symbols or trampolines at all. This is better than having to inline
+  # to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
   # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
   prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
     AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
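The 'static' claim in this comment can be checked outside tinygrad. A sketch, assuming clang and nm are on PATH (the source string and file names here are ours, not tinygrad's):

import os, subprocess, tempfile
src = "static int f(int x){return x+1;}\nint g(int x){return f(x);}\n"
with tempfile.TemporaryDirectory() as d:
    c, o = os.path.join(d, "t.c"), os.path.join(d, "t.o")
    with open(c, "w") as fh: fh.write(src)
    subprocess.run(["clang", "-c", "-O0", c, "-o", o], check=True)  # -O0 so f isn't inlined away
    print(subprocess.run(["nm", o], capture_output=True, text=True).stdout)
    # nm marks g with 'T' (exported); the static f shows up as lowercase 't' (local symbol)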


@@ -337,7 +337,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
     if v0 == v1:
       uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
       continue
-    # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
+    # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
     candidates = []
     if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
       # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
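The simplex check in that last comment, restated with plain integers (a stand-in expression and brute-force sampling instead of UOp vmin/vmax reasoning):

import itertools
def expr(x0, x1): return 1 if (x0 + x1) > 0 else 0      # stand-in for the uop being simplified
# valid says x0 + x1 > 0 with x0, x1 >= 0, so at least one xi >= 1; try each candidate
outs = set()
for i in range(2):
    for xs in itertools.product(range(4), repeat=2):
        if xs[i] >= 1: outs.add(expr(*xs))
if len(outs) == 1: print("rewrite uop to", outs.pop())  # every candidate agrees: rewrite uop to 1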