mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge branch 'triton-mlir' into ifu-231117
This commit is contained in:
@@ -1190,7 +1190,7 @@ def is_hip():
|
||||
|
||||
def mfma_supported_granularity(m, n, k) -> bool:
|
||||
# todo make this gran_type matrix element type sensitive
|
||||
for gran_type in [(32, 8), (16, 16)]:
|
||||
for gran_type in [(32, 8), (16, 16), (4, 64)]:
|
||||
granularity_mn, granularity_k = gran_type
|
||||
|
||||
if m % granularity_mn != 0 or n % granularity_mn != 0:
|
||||
@@ -1261,11 +1261,15 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_nu
|
||||
|
||||
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
|
||||
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
|
||||
assert lhs.shape[1].value == rhs.shape[
|
||||
0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
if _is_cuda(builder.target):
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
if is_hip():
|
||||
assert lhs.shape[0].value >= 4 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 4, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 4!"
|
||||
|
||||
# hip for now converts fp8 to fp16 for mixed input
|
||||
if is_hip():
|
||||
|
||||
Reference in New Issue
Block a user