mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
CLANG fixed ops python [run_process_replay] (#6866)
* hotfix: fixed values in ops_python for AMX * hotfix: remove unused import
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
from typing import Tuple, List, Optional, Any, Dict
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv, flatten, prod
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate, UOps, UOp
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -180,10 +180,9 @@ class PythonProgram:
|
||||
def c_map(lane, elem): return (lane, elem)
|
||||
ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
elif arg[4] == "CLANG":
|
||||
wmma_sz = [prod(x[1] for x in l) for l in arg[6]]
|
||||
def elem(x, i, j, _): return x[i+j][0]
|
||||
def c_map(_, elem): return (elem%wmma_sz[0], elem//wmma_sz[0])
|
||||
ul[i] = wmma_helper(1, 1, wmma_sz[0], wmma_sz[1], wmma_sz[2], elem, elem, c_map)
|
||||
def c_map(_, elem): return (elem%16, elem//16)
|
||||
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif uop is UOps.ALU:
|
||||
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
|
||||
|
||||
Reference in New Issue
Block a user