Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-09 06:05:11 -05:00)
make EMULATE a context var (#12002)
* make EMULATE a context var
* fix test amx
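A minimal sketch of the pattern this commit adopts, under assumed names (not tinygrad's actual code): one string-valued EMULATE setting selects which backend the Python emulator mimics, replacing the separate boolean EMULATE_<DEVICE> environment switches that the diff below removes.

import os

def pick_emulated_device(emulate: str) -> str:
  # one string value replaces the old per-backend boolean EMULATE_* flags
  match emulate:
    case "METAL" | "CUDA" | "AMD" | "INTEL": return emulate
    case "AMX": return "CPU"        # AMX emulation targets the CPU backend
    case "": return "PYTHON"        # unset/empty means no emulation
    case _: raise RuntimeError(f"can't EMULATE device: {emulate}")

assert pick_emulated_device("AMX") == "CPU"
print(pick_emulated_device(os.getenv("EMULATE", "")))  # e.g. EMULATE=METAL -> METAL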
@@ -2,10 +2,10 @@
 # a python uops emulator
 # works to test the tensor cores, and all the uops in general
 # this is the (living) definition of uops
-from typing import Any, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING, cast
 import pickle, base64, itertools, time, struct, sys
 from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
-from tinygrad.helpers import all_same, getenv, flatten, get_single_element
+from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
 from tinygrad.device import Compiled, Compiler, Allocator
 from tinygrad.codegen.opt import tc
 from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp
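One reading of the new typing import, for context (my interpretation, not stated in the commit): typing.cast does nothing at runtime, it only narrows the static type, so cast(str, EMULATE.value) lets the match/case on string literals below type-check, presumably because the context var's value is not declared as str.

from typing import cast

value: object = "METAL"    # stand-in for EMULATE.value; its static type is assumed to be looser than str
device = cast(str, value)  # no runtime conversion; only tells the type checker "treat this as str"
assert device is value and device == "METAL"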
@@ -210,14 +210,17 @@ class PythonRenderer(Renderer):
   device = "PYTHON"
   code_for_op = python_alu
   def __init__(self):
-    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", tc.metal
-    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", tc.amd_rdna3
-    if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", tc.amd_cdna
-    if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", tc.amd_rdna4
-    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
-    if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
-    if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
-    if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", tc.amx
+    match cast(str, EMULATE.value):
+      case "METAL": self.device, self.tensor_cores = "METAL", tc.metal
+      case "AMD": self.device, self.tensor_cores = "AMD", tc.amd_rdna3
+      case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna
+      case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4
+      case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
+      case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
+      case "INTEL": self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
+      case "AMX": self.device, self.tensor_cores = "CPU", tc.amx
+      case "": pass
+      case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}")
 
   def render(self, uops:list[UOp]) -> str:
     # the value of SPECIAL comes from local/global_size, not from its source
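A hedged usage sketch of the new selection path, assuming a tinygrad checkout that includes this commit; the module path tinygrad.runtime.ops_python, the import-time read of the environment, and the expectation that tc.metal is non-empty are assumptions on my part, not facts from the diff.

import os
os.environ["EMULATE"] = "METAL"  # was: os.environ["EMULATE_METAL"] = "1"; assumed it must be set before tinygrad is imported

from tinygrad.runtime.ops_python import PythonRenderer  # module path assumed from the ops_python emulator file

r = PythonRenderer()
print(r.device)                 # expected: METAL
print(len(r.tensor_cores) > 0)  # expected: True, since tc.metal was assigned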