mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
fix-typos (#10879)
@@ -153,7 +153,7 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
   assert False
 
 def fast_idiv(ctx: Renderer|None, x: UOp, d: int) -> UOp|None:
-  # idiv is truncated division, but arithmatic shift is floored division, so can only do non-negative numbers!
+  # idiv is truncated division, but arithmetic shift is floored division, so can only do non-negative numbers!
   if x.vmin<0: return None
   sign = 1 if d > 0 else -1
   m,s = magicgu(vmax := min(x.vmax, dtypes.max(x.dtype)), abs(d))
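For context on the hunk above: fast_idiv replaces an integer division by a constant with a multiply and a shift, using the (m, s) pair from magicgu, and the vmin check exists because an arithmetic shift floors while idiv truncates, so the identity only holds for non-negative x. A minimal sketch of the idea, with a hypothetical brute-force search standing in for magicgu:

# magic-number division: find (m, s) with (x * m) >> s == x // d for all
# 0 <= x <= vmax. the brute-force find_magic here is a hypothetical stand-in
# for tinygrad's magicgu, which computes (m, s) directly.
def find_magic(vmax: int, d: int) -> tuple[int, int]:
  for s in range(64):
    m = (2**s + d - 1) // d                # ceil(2**s / d)
    if all((x * m) >> s == x // d for x in range(vmax + 1)):
      return m, s
  raise AssertionError("no magic number found")

m, s = find_magic(vmax=255, d=3)
print(m, s)  # 171 9: x // 3 == (x * 171) >> 9 for every x in [0, 255]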
@@ -134,7 +134,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
     ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
-  # REDUCE supports both "horizonal" reduction and range reduction. the horizonal elements are taken in the nearest group
+  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
   return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op)
 
 def lower_load_store(ctx: IndexContext, x: UOp, buf: UOp):
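A note on the wording in the hunk above: "horizontal" reduction folds the lanes of a single vectorized value (the CONTRACTed elements), while range reduction accumulates across loop iterations (the RANGE sources). A toy illustration in plain Python, not tinygrad's lowering:

# toy illustration of the two reduction styles the comment describes.
from functools import reduce
import operator

def horizontal_reduce(vec: list[int]) -> int:
  # all elements already live in one vector value: fold across lanes
  return reduce(operator.add, vec)

def range_reduce(f, n: int) -> int:
  # elements are produced one per loop iteration: fold across the range
  acc = 0
  for i in range(n):
    acc += f(i)
  return acc

assert horizontal_reduce([1, 2, 3, 4]) == 10
assert range_reduce(lambda i: i + 1, 4) == 10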
@@ -47,7 +47,7 @@ class Optimizer:
         f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
             - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
     if self.fused:
-      # optimizer fusion just concatentates all the buffers, runs the _step, then splits them back up
+      # optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
       out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
                               [Tensor.cat(*[unwrap(t.grad).flatten() for t in self.params], dim=0)])
       updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
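The fused path in the hunk above depends on recorded offsets (pos_params) so each parameter can be sliced back out of the concatenated buffer after the step. A minimal sketch of that pattern with plain lists; the values and the stub _step are made up:

# flatten and concatenate all params, run one step on the single big buffer,
# then slice each param back out by its recorded offset (like pos_params).
from itertools import accumulate

params = [[1., 2., 3., 4.], [5., 6., 7.]]        # two already-flat "tensors"
pos = [0, *accumulate(len(p) for p in params)]   # [0, 4, 7]

def _step(buf):                                  # stub standing in for the update
  return [x - 1.0 for x in buf]

big = [x for p in params for x in p]             # the Tensor.cat of the real code
out = _step(big)
updated = [out[pos[i]:pos[i + 1]] for i in range(len(params))]
assert updated == [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]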
@@ -26,7 +26,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
   elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
-  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
   swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
   def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
   def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
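The opts encoding fixed in the hunk above is compact: each entry is an action character ('u' upcast, 'l' local) followed by a dim index, which is why get_upcast_axes can filter on opt[0] == "u". A toy decoder; decode_opt is a hypothetical helper, not part of tinygrad:

# toy decoder for the "ux"/"lx" opts encoding described in the comment.
def decode_opt(opt: str) -> tuple[str, int]:
  action = {"u": "upcast", "l": "local"}[opt[0]]
  return action, int(opt[1])               # dim index the opt applies to

opts = ("u0", "u1", "l0", "l1")             # made-up example sequence
print([decode_opt(o) for o in opts])
# [('upcast', 0), ('upcast', 1), ('local', 0), ('local', 1)]
print([o for o in opts if o[0] == "u"])     # the get_upcast_axes filter: ['u0', 'u1']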
@@ -216,7 +216,7 @@ class ClangRenderer(CStyleLanguage):
       '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
     ]
     # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
-    # to just jump at the start of a shellcode whithout having to deal with symbols or trampolines at all. This is better than having to inline
+    # to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
     # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
     prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
  AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
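The 'static' point in the hunk above can be checked directly: a static C function never lands in the shared object's dynamic symbol table, so only the exported entry point is visible after loading. A small sketch (assumes clang is available on PATH; not tinygrad code):

# demonstrate that 'static' keeps a C symbol local to the object file.
import ctypes, os, subprocess, tempfile

src = """
static int helper(int x) { return x * 2; }  /* 'static': symbol stays local */
int entry(int x) { return helper(x) + 1; }  /* exported */
"""
with tempfile.TemporaryDirectory() as d:
  c_path, so_path = os.path.join(d, "f.c"), os.path.join(d, "f.so")
  with open(c_path, "w") as f: f.write(src)
  subprocess.run(["clang", "-shared", "-O2", "-o", so_path, c_path], check=True)
  lib = ctypes.CDLL(so_path)
  print(lib.entry(20))                # 41: the exported symbol resolves and runs
  assert not hasattr(lib, "helper")   # the static symbol is not in the dynamic table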
@@ -337,7 +337,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
       if v0 == v1:
         uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
         continue
-      # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
+      # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
       candidates = []
       if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
         # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
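The simplex comment in the hunk above can be restated: if valid is X0 + X1 + ... > 0 and every Xi is non-negative, then some Xi must be at least 1; if uop simplifies to the same value under each single assumption Xi >= 1, that value is safe unconditionally. A toy sketch over plain bounds, with every name a hypothetical stand-in for UOp substitution and simplify:

# valid is x0 + x1 > 0 with x0, x1 >= 0; uop is the test (x0 + x1 != 0).
def simplify_nonzero(bounds: dict[str, tuple[int, int]]):
  # tries to simplify (x0 + x1 != 0) given per-variable (vmin, vmax) bounds
  vmin = sum(lo for lo, _ in bounds.values())
  return True if vmin > 0 else None              # None: cannot decide

base = {"x0": (0, 10), "x1": (0, 10)}
assert simplify_nonzero(base) is None            # x >= 0 alone decides nothing

# valid says x0 + x1 > 0, so at least one xi >= 1: try each candidate in turn
candidates = [base | {"x0": (1, 10)}, base | {"x1": (1, 10)}]
results = [simplify_nonzero(c) for c in candidates]
if all(r is not None for r in results) and len(set(results)) == 1:
  print("rewrite uop ->", results[0])            # rewrite uop -> True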