CLANG fixed ops python [run_process_replay] (#6866)

* hotfix: fixed values in ops_python for AMX

* hotfix: remove unused import
This commit is contained in:
ignaciosica
2024-10-03 09:40:04 -03:00
committed by GitHub
parent 4b6732c4f6
commit 8931f20765

View File

@@ -5,7 +5,7 @@
from typing import Tuple, List, Optional, Any, Dict
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten, prod
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate, UOps, UOp
from tinygrad.renderer import Renderer
@@ -180,10 +180,9 @@ class PythonProgram:
def c_map(lane, elem): return (lane, elem)
ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif arg[4] == "CLANG":
wmma_sz = [prod(x[1] for x in l) for l in arg[6]]
def elem(x, i, j, _): return x[i+j][0]
def c_map(_, elem): return (elem%wmma_sz[0], elem//wmma_sz[0])
ul[i] = wmma_helper(1, 1, wmma_sz[0], wmma_sz[1], wmma_sz[2], elem, elem, c_map)
def c_map(_, elem): return (elem%16, elem//16)
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop is UOps.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"