get_single_element [pr] (#8328)

This commit is contained in:
George Hotz
2024-12-18 22:23:45 -08:00
committed by GitHub
parent 423d823c50
commit 3a9ca62b9e
3 changed files with 7 additions and 7 deletions

View File

@@ -6,7 +6,7 @@ import sys
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
@@ -114,9 +114,7 @@ class PythonProgram:
elif uop is Ops.ASSIGN:
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
ul[i] = inp[0]
elif uop is Ops.GEP:
assert len(arg) == 1
ul[i] = inp[0][arg[0]]
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
elif uop is Ops.WMMA:
# here are the models for the WMMA instruction on the different hardware
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):