feat: add support for np.ones_like and np.zeros_like

This commit is contained in:
Umut
2022-05-17 10:58:33 +02:00
parent 6c20ed7a08
commit 56085204e6
3 changed files with 32 additions and 0 deletions

View File

@@ -210,6 +210,7 @@ class Tracer:
np.negative,
np.nextafter,
np.not_equal,
np.ones_like,
np.positive,
np.power,
np.rad2deg,
@@ -234,12 +235,16 @@ class Tracer:
np.true_divide,
np.trunc,
np.where,
np.zeros_like,
}
SUPPORTED_KWARGS: Dict[Any, Set[str]] = {
np.concatenate: {
"axis",
},
np.ones_like: {
"dtype",
},
np.reshape: {
"newshape",
},
@@ -247,6 +252,9 @@ class Tracer:
"axis",
"keepdims",
},
np.zeros_like: {
"dtype",
},
}
@staticmethod
@@ -279,6 +287,14 @@ class Tracer:
f"Function 'np.{operation.__name__}' is not supported with kwarg '{kwarg}'"
)
if operation == np.ones_like: # pylint: disable=comparison-with-callable
dtype = kwargs.get("dtype", np.int64)
return Tracer(Node.constant(np.ones(args[0].shape, dtype=dtype)), [])
if operation == np.zeros_like: # pylint: disable=comparison-with-callable
dtype = kwargs.get("dtype", np.int64)
return Tracer(Node.constant(np.zeros(args[0].shape, dtype=dtype)), [])
def sampler(arg: Any) -> Any:
if isinstance(arg, tuple):
return tuple(sampler(item) for item in arg)

View File

@@ -69,6 +69,7 @@ List of supported functions:
- negative
- nextafter
- not_equal
- ones_like
- positive
- power
- rad2deg
@@ -93,6 +94,7 @@ List of supported functions:
- true_divide
- trunc
- where
- zeros_like
<!--- gen_supported_ufuncs.py: inject supported operations [END] -->
# Shapes

View File

@@ -447,6 +447,20 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
},
id="np.where(x < 5, x * 3, x)",
),
pytest.param(
lambda x: x + np.ones_like(x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
id="x + np.ones_like(x)",
),
pytest.param(
lambda x: x + np.zeros_like(x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
id="x + np.zeros_like(x)",
),
],
)
def test_others(function, parameters, helpers):