chore: integrate ruff

This commit is contained in:
Umut
2022-12-23 11:11:50 +01:00
parent 85f36817cf
commit 6a37249ec5
38 changed files with 327 additions and 191 deletions

View File

@@ -62,8 +62,12 @@ flake8:
poetry run flake8 --max-line-length 100 --per-file-ignores="__init__.py:F401" \
$(SRC_DIR)/ tests/ script/
.PHONY: ruff
ruff:
poetry run ruff $(SRC_DIR)/ tests/ script/
.PHONY: python_linting # Run python linters
python_linting: pylint flake8
python_linting: pylint flake8 ruff
.PHONY: conformance # Run command to fix some conformance issues automatically
conformance: finalize_nb python_format supported_functions licenses

View File

@@ -149,7 +149,7 @@ class DebugArtifacts:
next(dependencies)
for dependency in dependencies:
tokens = [token for token in dependency.split(" ") if token != ""]
tokens = [token for token in dependency.split(" ") if token != ""] # noqa: S105
if len(tokens) == 0:
continue

View File

@@ -75,7 +75,8 @@ class Circuit:
"""
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `keygen` method")
message = "Virtual circuits cannot use `keygen` method"
raise RuntimeError(message)
self.client.keygen(force)
@@ -93,7 +94,8 @@ class Circuit:
"""
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `encrypt` method")
message = "Virtual circuits cannot use `encrypt` method"
raise RuntimeError(message)
return self.client.encrypt(*args)
@@ -111,7 +113,8 @@ class Circuit:
"""
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `run` method")
message = "Virtual circuits cannot use `run` method"
raise RuntimeError(message)
self.keygen(force=False)
return self.server.run(args, self.client.evaluation_keys)
@@ -133,7 +136,8 @@ class Circuit:
"""
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `decrypt` method")
message = "Virtual circuits cannot use `decrypt` method"
raise RuntimeError(message)
return self.client.decrypt(result)

View File

@@ -123,7 +123,8 @@ class Client:
input_specs = client_parameters_json["inputs"]
if len(args) != len(input_specs):
raise ValueError(f"Expected {len(input_specs)} inputs but got {len(args)}")
message = f"Expected {len(input_specs)} inputs but got {len(args)}"
raise ValueError(message)
sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
for index, spec in enumerate(input_specs):
@@ -168,9 +169,10 @@ class Client:
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)
raise ValueError(
message = (
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)
raise ValueError(message)
self.keygen(force=False)
return ClientSupport.encrypt_arguments(

View File

@@ -115,12 +115,13 @@ class Compiler:
if len(missing_args) != 1:
parameter_str += f" and {repr(missing_args[-1])}"
raise ValueError(
message = (
f"Encryption status{'es' if len(missing_args) > 1 else ''} "
f"of parameter{'s' if len(missing_args) > 1 else ''} "
f"{parameter_str} of function '{function.__name__}' "
f"{'are' if len(missing_args) > 1 else 'is'} not provided"
)
raise ValueError(message)
additional_args = list(parameter_encryption_statuses)
for arg in signature.parameters.keys():
@@ -134,12 +135,13 @@ class Compiler:
if len(additional_args) != 1:
parameter_str += f" and {repr(additional_args[-1])}"
raise ValueError(
message = (
f"Encryption status{'es' if len(additional_args) > 1 else ''} "
f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but "
f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter "
f"of function '{function.__name__}'"
)
raise ValueError(message)
self.function = function # type: ignore
self.parameter_encryption_statuses = {
@@ -168,9 +170,8 @@ class Compiler:
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
]:
if len(kwargs) != 0:
raise RuntimeError(
f"Calling function '{self.function.__name__}' with kwargs is not supported"
)
message = f"Calling function '{self.function.__name__}' with kwargs is not supported"
raise RuntimeError(message)
sample = args[0] if len(args) == 1 else args
@@ -257,10 +258,11 @@ class Compiler:
"a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values"
)
raise ValueError(
message = (
f"Input #{index} of your inputset is not well formed "
f"(expected {expected} got {actual})"
)
raise ValueError(message)
if self.configuration.auto_adjust_rounders:
AutoRounder.adjust(self.function, self.inputset)
@@ -269,10 +271,11 @@ class Compiler:
try:
first_sample = next(iter(self.inputset))
except StopIteration as error:
raise RuntimeError(
message = (
f"{action} function '{self.function.__name__}' "
f"without an inputset is not supported"
) from error
)
raise RuntimeError(message) from error
self._trace(first_sample)
assert self.graph is not None

View File

@@ -41,19 +41,16 @@ class Configuration:
if not self.enable_unsafe_features:
if self.use_insecure_key_cache:
raise RuntimeError(
"Insecure key cache cannot be used without enabling unsafe features"
)
message = "Insecure key cache cannot be used without enabling unsafe features"
raise RuntimeError(message)
if self.virtual:
raise RuntimeError(
"Virtual compilation is not allowed without enabling unsafe features"
)
message = "Virtual compilation is not allowed without enabling unsafe features"
raise RuntimeError(message)
if self.use_insecure_key_cache and self.insecure_key_cache_location is None:
raise RuntimeError(
"Insecure key cache cannot be enabled without specifying its location"
)
message = "Insecure key cache cannot be enabled without specifying its location"
raise RuntimeError(message)
# pylint: disable=too-many-arguments
@@ -119,7 +116,8 @@ class Configuration:
hints = get_type_hints(Configuration)
for name, value in kwargs.items():
if name not in hints:
raise TypeError(f"Unexpected keyword argument '{name}'")
message = f"Unexpected keyword argument '{name}'"
raise TypeError(message)
hint = hints[name]
expected = None
@@ -151,10 +149,11 @@ class Configuration:
if not is_correctly_typed:
if expected is None:
expected = hint.__name__ if hasattr(hint, "__name__") else str(hint)
raise TypeError(
message = (
f"Unexpected type for keyword argument '{name}' "
f"(expected '{expected}', got '{type(value).__name__}')"
)
raise TypeError(message)
setattr(result, name, value)

View File

@@ -53,11 +53,12 @@ def circuit(
)
if not (is_value or is_scalar_annotation):
raise ValueError(
message = (
f"Annotation {annotation} for argument '{name}' is not valid "
f"(please use a cnp type such as "
f"`cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')"
)
raise ValueError(message)
parameter_values[name] = (
annotation if is_value else Value(annotation.dtype, shape=(), is_encrypted=False)

View File

@@ -169,7 +169,8 @@ class Server:
if via_mlir:
if self._mlir is None or self._configuration is None:
raise RuntimeError("Loaded server objects cannot be saved again via MLIR")
message = "Loaded server objects cannot be saved again via MLIR"
raise RuntimeError(message)
with tempfile.TemporaryDirectory() as tmp:
@@ -190,7 +191,8 @@ class Server:
return
if self._output_dir is None:
raise RuntimeError("Just-in-Time compilation cannot be saved")
message = "Just-in-Time compilation cannot be saved"
raise RuntimeError(message)
with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
f.write(self.client_specs.serialize())

View File

@@ -11,6 +11,8 @@ from ..dtypes import Float, Integer
from ..representation import Graph, Node, Operation
from .artifacts import DebugArtifacts
# ruff: noqa: ERA001
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
"""
@@ -267,12 +269,12 @@ def find_single_lca(graph: Graph, nodes: List[Node]) -> Optional[Node]:
# find common ancestors among `nodes`
# if the single lca exists, it's in this set
common_ancestors = set(
common_ancestors = {
node
for node in nx_graph.nodes()
if node.operation != Operation.Constant
and all(node in ancestors for ancestors in all_ancestors)
)
}
# iterate over every node in the graph reversed topological order
# this is to ensure result, if found, is the single "lowest" common ancestor
@@ -344,9 +346,7 @@ def is_single_common_ancestor(
# - they don't affect fusability status
predecessor_count_in_subgraph = len(list(subgraph.predecessors(node)))
predecessor_count_in_nx_graph = len(
list(
pred for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant
)
[pred for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant]
)
# see if number of predecessors are different

View File

@@ -16,10 +16,11 @@ class Float(BaseDataType):
super().__init__()
if bit_width not in [16, 32, 64]:
raise ValueError(
message = (
f"Float({repr(bit_width)}) is not supported "
f"(bit width must be one of 16, 32 or 64)"
)
raise ValueError(message)
self.bit_width = bit_width

View File

@@ -58,7 +58,8 @@ class Integer(BaseDataType):
lower_bound = int(value.min())
upper_bound = int(value.max())
else:
raise ValueError(f"Integer cannot represent {repr(value)}")
message = f"Integer cannot represent {repr(value)}"
raise ValueError(message)
def bits_to_represent_int(value: int, force_signed: bool) -> int:
bits: int
@@ -92,10 +93,11 @@ class Integer(BaseDataType):
if not isinstance(bit_width, int) or bit_width <= 0:
integer_str = "SignedInteger" if is_signed else "UnsignedInteger"
raise ValueError(
message = (
f"{integer_str}({repr(bit_width)}) is not supported "
f"(bit width must be a positive integer)"
)
raise ValueError(message)
self.is_signed = is_signed
self.bit_width = bit_width

View File

@@ -44,7 +44,8 @@ def array(values: Any) -> Union[np.ndarray, Tracer]:
values[i] = Tracer.sanitize(value)
if not values[i].output.is_scalar:
raise ValueError("Encrypted arrays can only be created from scalars")
message = "Encrypted arrays can only be created from scalars"
raise ValueError(message)
dtype = combine_dtypes([value.output.dtype for value in values])
is_encrypted = True

View File

@@ -53,10 +53,11 @@ class AutoRounder:
def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH):
# pylint: disable=protected-access
if local._is_adjusting:
raise RuntimeError(
message = (
"AutoRounders cannot be constructed during adjustment, "
"please construct AutoRounders outside the function and reference it"
)
raise RuntimeError(message)
# pylint: enable=protected-access
self.target_msbs = target_msbs
@@ -82,7 +83,8 @@ class AutoRounder:
pass
if local._is_adjusting:
raise RuntimeError("AutoRounders cannot be adjusted recursively")
message = "AutoRounders cannot be adjusted recursively"
raise RuntimeError(message)
try:
local._is_adjusting = True
@@ -111,7 +113,8 @@ class AutoRounder:
return
if rounder is None:
raise ValueError("AutoRounders cannot be adjusted with an empty inputset")
message = "AutoRounders cannot be adjusted with an empty inputset"
raise ValueError(message)
rounder.is_adjusted = True
@@ -189,11 +192,12 @@ def round_bit_pattern(
raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x)))
elif not lsbs_to_remove.is_adjusted:
raise RuntimeError(
message = (
"AutoRounders cannot be used before adjustment, "
"please call AutoRounder.adjust with the function that will be compiled "
"and provide the exact inputset that will be used for compilation"
)
raise RuntimeError(message)
lsbs_to_remove = lsbs_to_remove.lsbs_to_remove
@@ -229,11 +233,13 @@ def round_bit_pattern(
if isinstance(x, np.ndarray):
if not np.issubdtype(x.dtype, np.integer):
raise TypeError(
message = (
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
)
raise TypeError(message)
elif not isinstance(x, (int, np.integer)):
raise TypeError(f"Expected input to be an int or a numpy array but it's {type(x).__name__}")
message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}"
raise TypeError(message)
return evaluator(x, lsbs_to_remove)

View File

@@ -53,7 +53,8 @@ class LookupTable:
self.output_dtype = Integer.that_can_represent([minimum, maximum])
if not is_valid:
raise ValueError(f"LookupTable cannot be constructed with {repr(table)}")
message = f"LookupTable cannot be constructed with {repr(table)}"
raise ValueError(message)
def __repr__(self):
return str(list(self.table))
@@ -63,17 +64,19 @@ class LookupTable:
return LookupTable.apply(key, self.table)
if not isinstance(key.output.dtype, Integer):
raise ValueError(f"LookupTable cannot be looked up with {key.output}")
message = f"LookupTable cannot be looked up with {key.output}"
raise ValueError(message)
table = self.table
if not np.issubdtype(self.table.dtype, np.integer):
try:
table = np.broadcast_to(table, key.output.shape)
except Exception as error:
raise ValueError(
message = (
f"LookupTable of shape {self.table.shape} "
f"cannot be looked up with {key.output}"
) from error
)
raise ValueError(message) from error
output = deepcopy(key.output)
output.dtype = self.output_dtype
@@ -114,13 +117,15 @@ class LookupTable:
if not isinstance(key, (int, np.integer, np.ndarray)) or (
isinstance(key, np.ndarray) and not np.issubdtype(key.dtype, np.integer)
):
raise ValueError(f"LookupTable cannot be looked up with {key}")
message = f"LookupTable cannot be looked up with {key}"
raise ValueError(message)
if np.issubdtype(table.dtype, np.integer):
return table[key]
if not isinstance(key, np.ndarray) or key.shape != table.shape:
raise ValueError(f"LookupTable of shape {table.shape} cannot be looked up with {key}")
message = f"LookupTable of shape {table.shape} cannot be looked up with {key}"
raise ValueError(message)
flat_result = np.fromiter(
(lt.table[k] for lt, k in zip(table.flat, key.flat)),

View File

@@ -59,7 +59,8 @@ def univariate(
output_value = Value.of(evaluation, is_encrypted=x.output.is_encrypted)
if output_value.shape != x.output.shape:
raise ValueError(f"Function {function.__name__} cannot be used with cnp.univariate")
message = f"Function {function.__name__} cannot be used with cnp.univariate"
raise ValueError(message)
# pylint: disable=protected-access
is_direct = Tracer._is_direct
@@ -67,11 +68,12 @@ def univariate(
if is_direct:
if outputs is None:
raise ValueError(
message = (
"Univariate extension requires "
"`outputs` argument for direct circuit definition "
"(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))"
)
raise ValueError(message)
output_value.dtype = outputs if isinstance(outputs, BaseDataType) else outputs.dtype
computation = Node.generic(

View File

@@ -28,4 +28,5 @@ def unreachable():
Raise a RuntimeError to indicate unreachable code is entered.
"""
raise RuntimeError("Entered unreachable code")
message = "Entered unreachable code"
raise RuntimeError(message)

View File

@@ -217,10 +217,11 @@ class GraphConverter:
offending_nodes[node] = [reason, node.location]
if len(offending_nodes) != 0:
raise RuntimeError(
message = (
"Function you are trying to compile cannot be converted to MLIR\n\n"
+ graph.format(highlighted_nodes=offending_nodes)
)
raise RuntimeError(message)
@staticmethod
def _update_bit_widths(graph: Graph):
@@ -466,7 +467,7 @@ class GraphConverter:
assert len(node.inputs) in {1, 2}
if len(node.inputs) == 2:
if set(inp.is_scalar for inp in node.inputs) != {True, False}:
if {inp.is_scalar for inp in node.inputs} != {True, False}:
continue
else:
if not node.inputs[0].is_scalar:

View File

@@ -85,7 +85,8 @@ class NodeConverter:
# the branch above is always taken due to compatibility checks
# still, it's a good idea to raise an appropriate error, just in case
raise ValueError(f"{value} cannot be converted to MLIR") # pragma: no cover
message = f"{value} cannot be converted to MLIR" # pragma: no cover
raise ValueError(message) # pragma: no cover
@staticmethod
def mlir_name(result: OpResult) -> str:
@@ -181,6 +182,8 @@ class NodeConverter:
assert_that(self.node.converted_to_table_lookup)
return self._convert_tlu()
# pylint: disable=no-self-use
def _convert_add(self) -> OpResult:
"""
Convert "add" node to its corresponding MLIR representation.
@@ -365,7 +368,8 @@ class NodeConverter:
in-memory MLIR representation corresponding to `self.node`
"""
raise NotImplementedError("conv1d conversion to MLIR is not yet implemented")
message = "conv1d conversion to MLIR is not yet implemented"
raise NotImplementedError(message)
def _convert_conv2d(self) -> OpResult:
"""
@@ -425,7 +429,8 @@ class NodeConverter:
in-memory MLIR representation corresponding to `self.node`
"""
raise NotImplementedError("conv3d conversion to MLIR is not yet implemented")
message = "conv3d conversion to MLIR is not yet implemented"
raise NotImplementedError(message)
def _convert_dot(self) -> OpResult:
"""
@@ -487,7 +492,8 @@ class NodeConverter:
in-memory MLIR representation corresponding to `self.node`
"""
raise NotImplementedError("MaxPool operation cannot be compiled yet")
message = "MaxPool operation cannot be compiled yet"
raise NotImplementedError(message)
def _convert_mul(self) -> OpResult:
"""
@@ -1076,3 +1082,5 @@ class NodeConverter:
# pylint: enable=too-many-function-args
self.constant_cache[(mlir_type, mlir_attribute)] = result
return result
# pylint: enable=no-self-use

View File

@@ -2,6 +2,8 @@
Declaration of various `Evaluator` classes, to make graphs picklable.
"""
# ruff: noqa: ARG002
class ConstantEvaluator:
"""

View File

@@ -181,6 +181,7 @@ class Graph:
"""
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
# ruff: noqa: ERA001
if self.is_direct:
show_bounds = False
@@ -396,7 +397,8 @@ class Graph:
}
except Exception as error:
raise RuntimeError(f"Bound measurement using inputset[{index}] failed") from error
message = f"Bound measurement using inputset[{index}] failed"
raise RuntimeError(message) from error
return bounds

View File

@@ -56,7 +56,8 @@ class Node:
try:
value = Value.of(constant)
except Exception as error:
raise ValueError(f"Constant {repr(constant)} is not supported") from error
message = f"Constant {repr(constant)} is not supported"
raise ValueError(message) from error
properties = {"constant": np.array(constant)}
return Node([], value, Operation.Constant, ConstantEvaluator(properties), properties)
@@ -199,26 +200,25 @@ class Node:
return result
if len(args) != len(self.inputs):
raise ValueError(
f"{generic_error_message()} failed because of invalid number of arguments"
)
message = f"{generic_error_message()} failed because of invalid number of arguments"
raise ValueError(message)
for arg, input_ in zip(args, self.inputs):
try:
arg_value = Value.of(arg)
except Exception as error:
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
raise ValueError(
f"{generic_error_message()} failed because {arg_str} is not valid"
) from error
message = f"{generic_error_message()} failed because {arg_str} is not valid"
raise ValueError(message) from error
if input_.shape != arg_value.shape:
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
raise ValueError(
message = (
f"{generic_error_message()} failed because "
f"{arg_str} does not have the expected "
f"shape of {input_.shape}"
)
raise ValueError(message)
result = self.evaluator(*args)
@@ -237,11 +237,12 @@ class Node:
pass
if not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
raise ValueError(
message = (
f"{generic_error_message()} resulted in {repr(result)} "
f"of type {result.__class__.__name__} "
f"which is not acceptable either because of the type or because of overflow"
)
raise ValueError(message)
if isinstance(result, np.ndarray):
dtype = result.dtype
@@ -250,18 +251,20 @@ class Node:
and not np.issubdtype(dtype, np.floating)
and not np.issubdtype(dtype, np.bool_)
):
raise ValueError(
message = (
f"{generic_error_message()} resulted in {repr(result)} "
f"of type np.ndarray and of underlying type '{type(dtype).__name__}' "
f"which is not acceptable because of the underlying type"
)
raise ValueError(message)
if result.shape != self.output.shape:
raise ValueError(
message = (
f"{generic_error_message()} resulted in {repr(result)} "
f"which does not have the expected "
f"shape of {self.output.shape}"
)
raise ValueError(message)
return result

View File

@@ -93,11 +93,12 @@ class Tracer:
try:
sanitized_tracers.append(Tracer.sanitize(tracer))
except Exception as error:
raise ValueError(
message = (
f"Function '{function.__name__}' "
f"returned '{tracer}', "
f"which is not supported"
) from error
)
raise ValueError(message) from error
output_tracers = tuple(sanitized_tracers)
@@ -133,10 +134,10 @@ class Tracer:
assert_that(nx.algorithms.dag.is_directed_acyclic_graph(graph))
unique_edges = set(
unique_edges = {
(pred, succ, tuple((k, v) for k, v in edge_data.items()))
for pred, succ, edge_data in graph.edges(data=True)
)
}
assert_that(len(unique_edges) == len(graph.edges))
return graph
@@ -342,14 +343,16 @@ class Tracer:
"""
if operation not in Tracer.SUPPORTED_NUMPY_OPERATORS:
raise RuntimeError(f"Function 'np.{operation.__name__}' is not supported")
message = f"Function 'np.{operation.__name__}' is not supported"
raise RuntimeError(message)
supported_kwargs = Tracer.SUPPORTED_KWARGS.get(operation, set())
for kwarg in kwargs:
if kwarg not in supported_kwargs:
raise RuntimeError(
message = (
f"Function 'np.{operation.__name__}' is not supported with kwarg '{kwarg}'"
)
raise RuntimeError(message)
if operation == np.ones_like: # pylint: disable=comparison-with-callable
dtype = kwargs.get("dtype", np.int64)
@@ -422,7 +425,8 @@ class Tracer:
sanitized_args = [self.sanitize(arg) for arg in args]
return Tracer._trace_numpy_operation(ufunc, *sanitized_args, **kwargs)
raise RuntimeError("Only __call__ hook is supported for numpy ufuncs")
message = "Only __call__ hook is supported for numpy ufuncs"
raise RuntimeError(message)
def __array_function__(self, func, _types, args, kwargs):
"""
@@ -509,9 +513,10 @@ class Tracer:
if ndigits is None:
result = Tracer._trace_numpy_operation(np.around, self)
if self._is_direct:
raise RuntimeError(
message = (
"'round(x)' cannot be used in direct definition (you may use np.around instead)"
)
raise RuntimeError(message)
return result.astype(np.int64)
return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits)
@@ -586,10 +591,11 @@ class Tracer:
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
output_value.dtype = dtype.dtype
else:
raise ValueError(
message = (
"`astype` method must be called with a concrete.numpy type "
"for direct circuit definition (e.g., value.astype(cnp.uint4))"
)
raise ValueError(message)
computation = Node.generic(
"astype",
@@ -600,10 +606,11 @@ class Tracer:
return Tracer(computation, [self])
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
raise ValueError(
message = (
"`astype` method must be called with a "
"numpy type for compilation (e.g., value.astype(np.int64))"
)
raise ValueError(message)
dtype = np.dtype(dtype).type
if np.issubdtype(dtype, np.integer) and dtype != np.int64:
@@ -622,9 +629,11 @@ class Tracer:
def evaluator(x, dtype):
if np.any(np.isnan(x)):
raise ValueError("A `NaN` value is tried to be converted to integer")
message = "A `NaN` value is tried to be converted to integer"
raise ValueError(message)
if np.any(np.isinf(x)):
raise ValueError("An `Inf` value is tried to be converted to integer")
message = "An `Inf` value is tried to be converted to integer"
raise ValueError(message)
return x.astype(dtype)
else:
@@ -716,9 +725,10 @@ class Tracer:
valid = False
if not valid:
raise ValueError(
message = (
f"Indexing with '{format_indexing_element(indexing_element)}' is not supported"
)
raise ValueError(message)
output_value = deepcopy(self.output)
output_value.shape = np.zeros(output_value.shape)[index].shape
@@ -761,9 +771,10 @@ class Tracer:
valid = False
if not valid:
raise ValueError(
message = (
f"Assigning to '{format_indexing_element(indexing_element)}' is not supported"
)
raise ValueError(message)
np.zeros(self.output.shape)[index] = 1
@@ -808,7 +819,7 @@ class Tracer:
return self.output.size
@property
def T(self) -> "Tracer": # pylint: disable=invalid-name
def T(self) -> "Tracer": # pylint: disable=invalid-name # noqa: N802
"""
Trace numpy.ndarray.T.
"""

View File

@@ -9,6 +9,7 @@ from ..values import Value
from .tracer import ScalarAnnotation, TensorAnnotation
# pylint: disable=function-redefined,invalid-name,no-self-use,too-many-lines,using-constant-test
# ruff: noqa
# We'll pull a little trick on mypy

View File

@@ -89,7 +89,8 @@ class Value:
if np.issubdtype(value.dtype, np.float16):
return Value(dtype=Float(16), shape=value.shape, is_encrypted=is_encrypted)
raise ValueError(f"Value cannot represent {repr(value)}")
message = f"Value cannot represent {repr(value)}"
raise ValueError(message)
# pylint: enable=too-many-branches,too-many-return-statements

View File

@@ -18,7 +18,7 @@ SUPPORTED_AUTO_PAD = {
}
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches,too-many-statements
def conv(
@@ -66,51 +66,63 @@ def conv(
if kernel_shape is not None and (
(weight.ndim - 2) != len(kernel_shape) or not np.all(weight.shape[2:] == kernel_shape)
):
raise ValueError(f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}")
message = f"expected kernel_shape to be {weight.shape[2:]}, but got {kernel_shape}"
raise ValueError(message)
if isinstance(x, np.ndarray):
if not isinstance(weight, np.ndarray):
raise TypeError("expected weight to be of same type as x")
message = "expected weight to be of same type as x"
raise TypeError(message)
if bias is not None and not isinstance(bias, np.ndarray):
raise TypeError("expected bias to be of same type as x")
message = "expected bias to be of same type as x"
raise TypeError(message)
elif isinstance(x, Tracer):
if not isinstance(weight, (Tracer, np.ndarray)):
raise TypeError("expected weight to be of type Tracer or ndarray")
message = "expected weight to be of type Tracer or ndarray"
raise TypeError(message)
if bias is not None and not isinstance(bias, (Tracer, np.ndarray)):
raise TypeError("expected bias to be of type Tracer or ndarray")
message = "expected bias to be of type Tracer or ndarray"
raise TypeError(message)
if x.ndim <= 2:
raise ValueError(
message = (
f"expected input x to have at least 3 dimensions (N, C, D1, ...), but got {x.ndim}"
)
raise ValueError(message)
if weight.ndim <= 2:
raise ValueError(
message = (
f"expected weight to have at least 3 dimensions (F, C / group, K1, ...), but got "
f"{weight.ndim}"
)
raise ValueError(message)
if bias is not None and bias.ndim != 1:
raise ValueError(f"expected bias to have a single dimension (F,), but got {bias.ndim}")
message = f"expected bias to have a single dimension (F,), but got {bias.ndim}"
raise ValueError(message)
if not isinstance(group, int) or group <= 0:
raise ValueError(f"expected group to be an integer > 0, but got {group}")
message = f"expected group to be an integer > 0, but got {group}"
raise ValueError(message)
if auto_pad not in SUPPORTED_AUTO_PAD:
raise ValueError(f"auto_pad should be in {SUPPORTED_AUTO_PAD}, but got {repr(auto_pad)}")
message = f"auto_pad should be in {SUPPORTED_AUTO_PAD}, but got {repr(auto_pad)}"
raise ValueError(message)
n_channels = x.shape[1]
if weight.shape[1] != n_channels / group:
raise ValueError(
message = (
f"expected number of channel in weight to be {n_channels / group} (C / group), but got "
f"{weight.shape[1]}"
)
raise ValueError(message)
if weight.shape[0] % group != 0:
raise ValueError(
message = (
f"expected number of feature maps ({weight.shape[0]}) to be a multiple of group "
f"({group})"
)
raise ValueError(message)
dims = x.ndim - 2
if dims == 1:
@@ -155,7 +167,9 @@ def conv(
group=group,
auto_pad=auto_pad,
)
raise NotImplementedError("only 1D, 2D, and 3D convolutions are supported")
message = "only 1D, 2D, and 3D convolutions are supported"
raise NotImplementedError(message)
# pylint: enable=too-many-branches
@@ -206,21 +220,24 @@ def _conv1d(
)
if len(pads) != 2:
raise ValueError(
message = (
f"pads should be of form "
f"(D_begin_pad, D_end_pad) when performing "
f"1D convolution, but it's {pads}"
)
raise ValueError(message)
if len(strides) != 1:
raise ValueError(
message = (
f"strides should be of form (D_stride,) when performing 1D "
f"convolution, but it's {strides}"
)
raise ValueError(message)
if len(dilations) != 1:
raise ValueError(
message = (
f"dilations should be of form (D_dilation,) when performing 1D "
f"convolution, but it's {dilations}"
)
raise ValueError(message)
return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)
@@ -270,21 +287,24 @@ def _conv2d(
)
if len(pads) != 4:
raise ValueError(
message = (
f"pads should be of form "
f"(height_begin_pad, width_begin_pad, height_end_pad, width_end_pad) when performing "
f"2D convolution, but it's {pads}"
)
raise ValueError(message)
if len(strides) != 2:
raise ValueError(
message = (
f"strides should be of form (height_stride, width_stride) when performing 2D "
f"convolution, but it's {strides}"
)
raise ValueError(message)
if len(dilations) != 2:
raise ValueError(
message = (
f"dilations should be of form (height_dilation, width_dilation) when performing 2D "
f"convolution, but it's {dilations}"
)
raise ValueError(message)
return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)
@@ -334,22 +354,25 @@ def _conv3d(
)
if len(pads) != 6:
raise ValueError(
message = (
f"pads should be of form "
f"(D_begin_pad, height_begin_pad, width_begin_pad, "
f"D_end_pad, height_end_pad, width_end_pad) when performing "
f"3D convolution, but it's {pads}"
)
raise ValueError(message)
if len(strides) != 3:
raise ValueError(
message = (
f"strides should be of form (D_stride, height_stride, width_stride) when performing "
f"3D convolution, but it's {strides}"
)
raise ValueError(message)
if len(dilations) != 3:
raise ValueError(
message = (
f"dilations should be of form (D_dilation, height_dilation, width_dilation) when "
f"performing 3D convolution, but it's {dilations}"
)
raise ValueError(message)
return _trace_or_eval(x, weight, bias, pads, strides, dilations, group)
@@ -632,11 +655,12 @@ def _evaluate_conv(
torch_padding = []
for dim in range(n_dim):
if pads[dim] != pads[n_dim + dim]:
raise ValueError(
message = (
f"padding should be the same for the beginning of the dimension and its end, but "
f"got {pads[dim]} in the beginning, and {pads[n_dim + dim]} at the end for "
f"dimension {dim}"
)
raise ValueError(message)
torch_padding.append(pads[dim])
dtype = (

View File

@@ -121,21 +121,22 @@ def maxpool(
value = tuple(value)
if not isinstance(value, tuple):
raise TypeError(
message = (
f"Expected {value_name} to be a tuple or a list but it's {type(value).__name__}"
)
raise TypeError(message)
for element in value:
if not isinstance(element, int):
raise TypeError(
message = (
f"Expected {value_name} to consist of integers "
f"but it has an element of type {type(element).__name__}"
)
raise TypeError(message)
if len(value) != size:
raise ValueError(
f"Expected {value_name} to have {size} elements but it has {len(value)}"
)
message = f"Expected {value_name} to have {size} elements but it has {len(value)}"
raise ValueError(message)
return value
@@ -153,24 +154,28 @@ def maxpool(
or np.issubdtype(x.dtype, np.floating)
or np.issubdtype(x.dtype, np.bool_)
):
raise TypeError(
message = (
f"Expected input elements to be of type np.integer, np.floating, or np.bool_ "
f"but it's {type(x.dtype).__name__}"
)
raise TypeError(message)
elif not isinstance(x, Tracer):
raise TypeError(
message = (
f"Expected input to be of type np.ndarray or Tracer "
f"but it's {type(auto_pad).__name__}"
)
raise TypeError(message)
if x.ndim < 3:
raise ValueError(
message = (
f"Expected input to have at least 3 dimensions (N, C, D1, ...) "
f"but it only has {x.ndim}"
)
raise ValueError(message)
if x.ndim > 5:
raise NotImplementedError(f"{x.ndim - 2}D maximum pooling is not supported yet")
message = f"{x.ndim - 2}D maximum pooling is not supported yet"
raise NotImplementedError(message)
# check kernel_shape
@@ -188,21 +193,20 @@ def maxpool(
# check auto_pad
if not isinstance(auto_pad, str):
raise TypeError(
f"Expected auto_pad to be of type str but it's {type(auto_pad).__name__}",
)
message = f"Expected auto_pad to be of type str but it's {type(auto_pad).__name__}"
raise TypeError(message)
if auto_pad not in AVAILABLE_AUTO_PAD:
raise ValueError(
message = (
f"Expected auto_pad to be one of "
f"{', '.join(sorted(AVAILABLE_AUTO_PAD))} "
f"but it's {auto_pad}",
f"but it's {auto_pad}"
)
raise ValueError(message)
if auto_pad not in SUPPORTED_AUTO_PAD:
raise NotImplementedError(
f"Desired auto_pad of {auto_pad} is not supported yet",
)
message = f"Desired auto_pad of {auto_pad} is not supported yet"
raise NotImplementedError(message)
# check pads
@@ -215,9 +219,8 @@ def maxpool(
pad_begin = pads[i]
pad_end = pads[i + len(pads) // 2]
if pad_begin != pad_end:
raise NotImplementedError(
f"Desired pads of {pads} is not supported yet because of uneven padding"
)
message = f"Desired pads of {pads} is not supported yet because of uneven padding"
raise NotImplementedError(message)
# check dilations
@@ -229,40 +232,40 @@ def maxpool(
# check ceil_mode
if not isinstance(ceil_mode, int):
raise TypeError(
f"Expected ceil_mode to be of type int but it's {type(ceil_mode).__name__}",
)
message = f"Expected ceil_mode to be of type int but it's {type(ceil_mode).__name__}"
raise TypeError(message)
if ceil_mode not in AVAILABLE_CEIL_MODE:
raise ValueError(
message = (
f"Expected ceil_mode to be one of "
f"{', '.join(sorted(str(x) for x in AVAILABLE_CEIL_MODE))} "
f"but it's {ceil_mode}",
f"but it's {ceil_mode}"
)
raise ValueError(message)
if ceil_mode not in SUPPORTED_CEIL_MODE:
raise NotImplementedError(
f"Desired ceil_mode of {ceil_mode} is not supported yet",
)
message = f"Desired ceil_mode of {ceil_mode} is not supported yet"
raise NotImplementedError(message)
# check storage_order
if not isinstance(storage_order, int):
raise TypeError(
f"Expected storage_order to be of type int but it's {type(storage_order).__name__}",
message = (
f"Expected storage_order to be of type int but it's {type(storage_order).__name__}"
)
raise TypeError(message)
if storage_order not in AVAILABLE_STORAGE_ORDER:
raise ValueError(
message = (
f"Expected storage_order to be one of "
f"{', '.join(sorted(str(x) for x in AVAILABLE_STORAGE_ORDER))} "
f"but it's {storage_order}",
f"but it's {storage_order}"
)
raise ValueError(message)
if storage_order not in SUPPORTED_STORAGE_ORDER:
raise NotImplementedError(
f"Desired storage_order of {storage_order} is not supported yet",
)
message = f"Desired storage_order of {storage_order} is not supported yet"
raise NotImplementedError(message)
# trace or evaluate
return _trace_or_evaluate(x, kernel_shape, strides, pads, dilations, ceil_mode == 1)

View File

@@ -100,7 +100,7 @@ Traceback (most recent call last):
File "/home/default/Documents/Projects/Zama/hdk/concrete/numpy/mlir/graph_converter.py", line 298, in convert
GraphConverter._check_graph_convertibility(graph)
File "/home/default/Documents/Projects/Zama/hdk/concrete/numpy/mlir/graph_converter.py", line 175, in _check_graph_convertibility
raise RuntimeError(
raise RuntimeError(message)
RuntimeError: Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<uint4>

69
poetry.lock generated
View File

@@ -125,21 +125,22 @@ files = [
[[package]]
name = "attrs"
version = "22.1.0"
version = "22.2.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=3.5"
python-versions = ">=3.6"
files = [
{file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"},
{file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"},
{file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"},
{file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"},
]
[package.extras]
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"]
dev = ["attrs[docs,tests]"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"]
tests = ["attrs[tests-no-zope]", "zope.interface"]
tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"]
[[package]]
name = "backcall"
@@ -955,14 +956,14 @@ files = [
[[package]]
name = "ipywidgets"
version = "8.0.3"
version = "8.0.4"
description = "Jupyter interactive widgets"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "ipywidgets-8.0.3-py3-none-any.whl", hash = "sha256:db7dd35fb1217636cbdbe0ba0bd2216d91a7695cb28b5c1dca17e62cd51378de"},
{file = "ipywidgets-8.0.3.tar.gz", hash = "sha256:2ec50df8538a1d4ddd5d454830d010922ad1015e81ac23efb27c0908bbc1eece"},
{file = "ipywidgets-8.0.4-py3-none-any.whl", hash = "sha256:ebb195e743b16c3947fe8827190fb87b4d00979c0fbf685afe4d2c4927059fa1"},
{file = "ipywidgets-8.0.4.tar.gz", hash = "sha256:c0005a77a47d77889cafed892b58e33b4a2a96712154404c6548ec22272811ea"},
]
[package.dependencies]
@@ -977,14 +978,14 @@ test = ["jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
[[package]]
name = "isort"
version = "5.11.3"
version = "5.11.4"
description = "A Python utility / library to sort Python imports."
category = "dev"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "isort-5.11.3-py3-none-any.whl", hash = "sha256:83155ffa936239d986b0f190347a3f2285f42a9b9e1725c89d865b27dd0627e5"},
{file = "isort-5.11.3.tar.gz", hash = "sha256:a8ca25fbfad0f7d5d8447a4314837298d9f6b23aed8618584c894574f626b64b"},
{file = "isort-5.11.4-py3-none-any.whl", hash = "sha256:c033fd0edb91000a7f09527fe5c75321878f98322a77ddcc81adbd83724afb7b"},
{file = "isort-5.11.4.tar.gz", hash = "sha256:6db30c5ded9815d813932c04c2f85a360bcdd35fed496f4d8f35495ef0a261b6"},
]
[package.extras]
@@ -1224,14 +1225,14 @@ files = [
[[package]]
name = "jupyterlab-widgets"
version = "3.0.4"
version = "3.0.5"
description = "Jupyter interactive widgets for JupyterLab"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "jupyterlab_widgets-3.0.4-py3-none-any.whl", hash = "sha256:4c9275daa6d20fc96c3aea45756ece7110850d035b0b93a6a40e918016b927da"},
{file = "jupyterlab_widgets-3.0.4.tar.gz", hash = "sha256:9a568e022b8bb53ab23291f6ddb52f8002b789c2c5763378cbc882be1d619be8"},
{file = "jupyterlab_widgets-3.0.5-py3-none-any.whl", hash = "sha256:a04a42e50231b355b7087e16a818f541e53589f7647144ea0344c4bf16f300e5"},
{file = "jupyterlab_widgets-3.0.5.tar.gz", hash = "sha256:eeaecdeaf6c03afc960ddae201ced88d5979b4ca9c3891bcb8f6631af705f5ef"},
]
[[package]]
@@ -3052,6 +3053,32 @@ files = [
[package.extras]
idna2008 = ["idna"]
[[package]]
name = "ruff"
version = "0.0.191"
description = "An extremely fast Python linter, written in Rust."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.0.191-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:77aa90ab83f6ef663ad002708eec13b7931193dfa418c09564ab34df4766a11d"},
{file = "ruff-0.0.191-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:75e8d628f4d1b0db216e8cb26a10f36a6c3db572e839aaad8ac0a51e0bd1c0ef"},
{file = "ruff-0.0.191-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ce09e5c6253854ce907a42640c316f97ff49b8a3e35ce3e44524e1fc67bf3f3"},
{file = "ruff-0.0.191-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0a230dad3805d70d4c2c9716e1cfe0939695979c8ef7b7b4791e3f5c26c5d965"},
{file = "ruff-0.0.191-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2727d02f9cacca0d945d1a4fa443ff9469a131c7791df1f8824b9c89721bd138"},
{file = "ruff-0.0.191-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:20b52da4ff1008f4401e4681c6b8552133aa0553cde17dfcd4d42d992242e578"},
{file = "ruff-0.0.191-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1084aa3a448c5f98c567547de79eef2dba6f022232fd9c89876881d74de3d3cd"},
{file = "ruff-0.0.191-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d76b88b0dbd6dfa6ac3ceb08bd519e140ad6e4f7cf9cf2d88310a084af35e92"},
{file = "ruff-0.0.191-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:614106b64fefd70750a2eb3e50bfce7c9b56585d7be96f925fb3c13108552984"},
{file = "ruff-0.0.191-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:304b83f1ab6f91245a800b3c60e38083b89ee240cbdb3d1d8a98671265292c11"},
{file = "ruff-0.0.191-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4bf38f1419b5d8b3bc299a2cba183a43ade4858af2160033fb94aa9e5a85c700"},
{file = "ruff-0.0.191-py3-none-musllinux_1_2_i686.whl", hash = "sha256:97c22a00daad210f03458e8993c62bee38ec98d2f0b8d50a671b31a96256cd5f"},
{file = "ruff-0.0.191-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c2a5b7e2b7aa33e5ece2291dc655ad2023626114179fc166d382e3902f319dd5"},
{file = "ruff-0.0.191-py3-none-win32.whl", hash = "sha256:195da807f65e1b153379b7258086cbfb3b10143ecf93683275dfef561593c485"},
{file = "ruff-0.0.191-py3-none-win_amd64.whl", hash = "sha256:a25a7b9d56732df7f4887f3bd4a66dd54924be8c4d359e6a73bd5b1a8082f1bf"},
{file = "ruff-0.0.191.tar.gz", hash = "sha256:d698c4d5e3b2963cbbb7c2728f404091d5c47cdf8d94db3eb2f335e2a93a6b1b"},
]
[[package]]
name = "secretstorage"
version = "3.3.3"
@@ -3511,14 +3538,14 @@ test = ["pytest (>=3.0.0)"]
[[package]]
name = "widgetsnbextension"
version = "4.0.4"
version = "4.0.5"
description = "Jupyter interactive widgets for Jupyter Notebook"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "widgetsnbextension-4.0.4-py3-none-any.whl", hash = "sha256:fa0e840719ec95dd2ec85c3a48913f1a0c29d323eacbcdb0b29bfed0cc6da678"},
{file = "widgetsnbextension-4.0.4.tar.gz", hash = "sha256:44c69f18237af0f610557d6c1c7ef76853f5856a0e604c0a517f2320566bb775"},
{file = "widgetsnbextension-4.0.5-py3-none-any.whl", hash = "sha256:eaaaf434fb9b08bd197b2a14ffe45ddb5ac3897593d43c69287091e5f3147bf7"},
{file = "widgetsnbextension-4.0.5.tar.gz", hash = "sha256:003f716d930d385be3fd9de42dd9bf008e30053f73bddde235d14fbeaeff19af"},
]
[[package]]
@@ -3601,4 +3628,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
[metadata]
lock-version = "2.0"
python-versions = ">=3.7,<3.11"
content-hash = "7d1b41b63f4e8a218e5d5708a822387f9500dae55aa71f84966b0ff329749d78"
content-hash = "a0ee23721d05e022e48ec1d4de5c7d955d0d896f82b580fef0fe29ed77372e4a"

View File

@@ -73,6 +73,7 @@ pip-licenses = "^3.5.3"
pip-audit = "^1.1.1"
pytest-codeblocks = "^0.12.2"
twine = "^3.7.1"
ruff = "^0.0.191"
[build-system]
requires = ["poetry-core>=1.0.0"]
@@ -90,3 +91,20 @@ filterwarnings = [
[tool.semantic_release]
version_toml = "pyproject.toml:tool.poetry.version"
upload_to_pypi = "False"
[tool.ruff]
target-version = "py37"
line-length = 100
select = [
"F", "E", "W", "C90", "I", "UP", "N", "YTT", "S", "BLE", "FBT", "B", "C4",
"T10", "EM", "ICN", "Q", "RET", "SIM", "TID", "ARG", "DTZ", "ERA", "PD", "PGH",
"PLC", "PLE", "PLR", "PLW", "RUF"
]
ignore = [
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100",
"RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003"
]
[tool.ruff.per-file-ignores]
"**/__init__.py" = ["F401"]
"tests/**" = ["PLC2201"]

View File

@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
"""Helper script for github actions to combine job statuses"""
import argparse
import json

View File

@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
"""Helper script for github actions"""
import argparse
import json

View File

@@ -60,7 +60,8 @@ def main(args):
if not VersionInfo.isvalid(strip_leading_v(version))
]
if len(invalid_versions) > 0:
raise RuntimeError(f"Found invalid versions: {invalid_versions}")
message = f"Found invalid versions: {invalid_versions}"
raise RuntimeError(message)
version_html = None
version_html_file_path = Path(args.versions_html_file).resolve()
@@ -68,13 +69,15 @@ def main(args):
version_html = BeautifulSoup(f, "html.parser")
if version_html is None:
raise RuntimeError(f"An error occured while trying to load {str(version_html_file_path)}")
message = f"An error occured while trying to load {str(version_html_file_path)}"
raise RuntimeError(message)
print(version_html)
version_list = version_html.find(id=VERSIONS_LIST_ID)
if version_list is None or version_list.name != "ul":
raise RuntimeError(f"Could not find <ul> tag with id {VERSIONS_LIST_ID}")
message = f"Could not find <ul> tag with id {VERSIONS_LIST_ID}"
raise RuntimeError(message)
non_semver_versions = {}
semver_versions = {}
@@ -103,10 +106,11 @@ def main(args):
versions_already_in_html = set(parsed_versions).intersection(semver_versions.keys())
if len(versions_already_in_html) > 0:
raise RuntimeError(
message = (
"The following versions are already in the html: "
f"{', '.join(str(ver) for ver in sorted(versions_already_in_html))}"
)
raise RuntimeError(message)
semver_versions.update(
(

View File

@@ -20,16 +20,16 @@ def main(args):
prerelease = args.prerelease
if not VersionInfo.isvalid(strip_leading_v(version)):
raise RuntimeError(f"Invalid version: {version}")
message = f"Invalid version: {version}"
raise RuntimeError(message)
version_json_file_path = Path(args.versions_json_file).resolve()
try:
with open(version_json_file_path, "r", encoding="utf-8") as f:
version_json = json.loads(f.read())
except JSONDecodeError as err:
raise RuntimeError(
f"An error occurred while trying to load {str(version_json_file_path)}"
) from err
message = f"An error occurred while trying to load {str(version_json_file_path)}"
raise RuntimeError(message) from err
# Version json is composed by:
# all: list of all published versions

View File

@@ -36,7 +36,6 @@ def main(args):
report.write("Found the following vulnerabilities:\n")
assert len(json_content) == 1
json_data = json.loads(json_content[0])
# print(json.dumps(json_data, indent=4))
for entry in json_data:
vuln_entries = entry.get("vulns", [])
if vuln_entries:

View File

@@ -11,7 +11,7 @@ def main(args):
pylintrc_file_path = Path(args.pylintrc_path).resolve()
config = configparser.ConfigParser()
config.read(pylintrc_file_path)
notes = sorted(map(lambda x: x.strip(), config["MISCELLANEOUS"]["notes"].split(",")))
notes = sorted(x.strip() for x in config["MISCELLANEOUS"]["notes"].split(","))
# Make sure we at least have todo in there without writing it otherwise we'll match
notes.append("TO" + "DO")
notes_for_grep_search = r"\|".join(notes)

View File

@@ -55,20 +55,22 @@ def get_code_blocks_for_file(md_file: Path) -> Dict[int, List[str]]:
line_idx, line = next(file_content_iterator)
if line == "":
# Reached EOF
raise SyntaxError(
message = (
"Reached EOF before finding the end of the current python block in "
f"{str(md_file)}"
)
raise SyntaxError(message)
if line.strip() == BLOCK_END:
break
if line.startswith(BLOCK_STARTS):
raise SyntaxError(
message = (
f"Error at line {line_idx} in file {str(md_file)}, "
"python block was opened before the previous one was "
"closed (missing ``` ?)"
)
raise SyntaxError(message)
current_python_code.append(line)
python_block_continues = False
else:
@@ -112,7 +114,7 @@ def main(args):
print(f"Testing block starting line #{line_idx} from {md_file_str}")
python_code = "".join(python_code)
compiled_code = compile(python_code, filename=md_file_str, mode="exec")
exec(compiled_code, {"__MODULE__": "__main__"})
exec(compiled_code, {"__MODULE__": "__main__"}) # noqa: S102
print("Success")
except Exception:
print("Failed")

View File

@@ -118,7 +118,8 @@ def set_version(args):
version_str = strip_leading_v(args.version)
if not VersionInfo.isvalid(version_str):
raise RuntimeError(f"Unable to validate version: {args.version}")
message = f"Unable to validate version: {args.version}"
raise RuntimeError(message)
file_vars_set = load_file_vars_set(args.pyproject_file, args.file_vars)
@@ -132,7 +133,8 @@ def set_version(args):
elif file_path.suffix == ".toml":
update_variable_in_toml_file(file_path, var_name, version_str)
else:
raise RuntimeError(f"Unsupported file extension: {file_path.suffix}")
message = f"Unsupported file extension: {file_path.suffix}"
raise RuntimeError(message)
def get_variable_from_py_file(file_path: Path, var_name: str):

View File

@@ -257,8 +257,7 @@ class Helpers:
break
if i == retries - 1:
raise AssertionError(
f"""
message = f"""
Expected Output
===============
@@ -269,7 +268,7 @@ Actual Output
{actual}
"""
)
raise AssertionError(message)
@staticmethod
def check_str(expected: str, actual: str):