make EMULATE a context var (#12002)

* make EMULATE a context var

* fix test amx
This commit is contained in:
George Hotz
2025-09-04 11:15:43 -07:00
committed by GitHub
parent 09106e4aae
commit 9dee724fc4
5 changed files with 55 additions and 49 deletions

View File

@@ -2,10 +2,10 @@
# a python uops emulator
# works to test the tensor cores, and all the uops in general
# this is the (living) definition of uops
from typing import Any, TYPE_CHECKING
from typing import Any, TYPE_CHECKING, cast
import pickle, base64, itertools, time, struct, sys
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp
@@ -210,14 +210,17 @@ class PythonRenderer(Renderer):
device = "PYTHON"
code_for_op = python_alu
def __init__(self):
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", tc.metal
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", tc.amd_rdna3
if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", tc.amd_cdna
if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", tc.amd_rdna4
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", tc.amx
match cast(str, EMULATE.value):
case "METAL": self.device, self.tensor_cores = "METAL", tc.metal
case "AMD": self.device, self.tensor_cores = "AMD", tc.amd_rdna3
case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna
case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4
case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
case "INTEL": self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
case "AMX": self.device, self.tensor_cores = "CPU", tc.amx
case "": pass
case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}")
def render(self, uops:list[UOp]) -> str:
# the value of SPECIAL comes from local/global_size, not form its source