mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
clean up more RANGEIFY flag (#12556)
This commit is contained in:
@@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings
|
||||
import numpy as np
|
||||
from typing import List, Callable
|
||||
import torch
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM, RANGEIFY
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -3040,7 +3040,6 @@ class TestOps(unittest.TestCase):
|
||||
pos_weight=torch.tensor(pos_weight)),
|
||||
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1),pos_weight=Tensor(pos_weight)))
|
||||
|
||||
@unittest.skipIf(RANGEIFY > 1, "broken on RANGEIFY > 1, TODO: fix")
|
||||
def test_cross_entropy_class_probabilities(self):
|
||||
helper_test_op([(32,), (32,)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
|
||||
|
||||
Reference in New Issue
Block a user