mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 17:45:38 -05:00
get_single_element [pr] (#8328)
This commit is contained in:
@@ -6,7 +6,7 @@ import sys
|
||||
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -114,9 +114,7 @@ class PythonProgram:
|
||||
elif uop is Ops.ASSIGN:
|
||||
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
|
||||
ul[i] = inp[0]
|
||||
elif uop is Ops.GEP:
|
||||
assert len(arg) == 1
|
||||
ul[i] = inp[0][arg[0]]
|
||||
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
|
||||
elif uop is Ops.WMMA:
|
||||
# here are the models for the WMMA instruction on the different hardware
|
||||
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
|
||||
|
||||
Reference in New Issue
Block a user