Compare commits

...

3 Commits

Author SHA1 Message Date
dante
fdc93b9572 fix: handle [] shapes in sort 2025-03-07 14:32:10 -05:00
dante
f631445e26 docs: document arguments better (#950) 2025-03-05 16:10:50 -05:00
dante
fcbb27677f fix: empty dim len can be 1 (#949) 2025-02-28 23:56:19 -05:00
5 changed files with 1110 additions and 123 deletions

File diff suppressed because one or more lines are too long

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@@ -926,6 +926,9 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
new_dims.iter().product::<usize>()
@@ -1104,6 +1107,10 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
} else if self.dims() == &[0] && shape.iter().product::<usize>() == 1 {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
}
if self.dims().len() > shape.len() {

View File

@@ -1342,9 +1342,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Gets the total number of elements in the tensor
pub fn len(&self) -> usize {
match self {
ValTensor::Value { dims, .. } => {
ValTensor::Value { dims, inner, .. } => {
if !dims.is_empty() && (dims != &[0]) {
dims.iter().product::<usize>()
} else if dims.is_empty() {
inner.inner.len()
} else {
0
}