"""
Abstract base class for binary tensors.
"""
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from numbers import Integral
from typing import Any, Literal, TypeVar, overload
from typing_extensions import Self
import numpy as np
import numpy.typing as npt
from typing_validation import validate
from .vectorized import BinTensor
# Type variable for NumPy scalar types, used to type the ``__array__``
# overloads of the tensor class below.
_ScalarType = TypeVar("_ScalarType", bound=np.generic)

#: Type alias for tensor shapes.
Shape = tuple[int, ...]
class bintensor:
    r"""
    Abstract base class for mutable binary tensors.

    The class cannot be instantiated directly: concrete tensor types are
    subclasses. Tensor data is held as a NumPy ``uint8`` array whose entries
    are all 0 or 1, and arithmetic on tensors is componentwise mod 2.
    """

    class ShapeError(ValueError):
        """
        Specialised :obj:`ValueError` subclass for shape errors.
        """

    class ReadonlyError(ValueError):
        """
        Specialised :obj:`ValueError` subclass for when attempting to mutate
        a readonly tensor.
        """

    @classmethod
    def make_readonly(cls, *tensors: bintensor) -> None:
        """
        Makes all given tensors readonly.
        """
        for t in tensors:
            t.readonly = True

    @classmethod
    def _validate_data(cls, data: BinTensor) -> Literal[True]:
        """
        Class method enforcing data validity.

        Can be overridden by subclasses to include additional checks.

        :raises ValueError: if some entry of ``data`` is neither 0 nor 1.
        """
        if not np.all((data == 0) | (data == 1)):
            raise ValueError("Bin array data value must be all 0 or 1.")
        return True

    _shape: tuple[int, ...]
    _data: BinTensor
    _readonly: bool
    __hash: int  # cached hash, only ever set for readonly tensors
    __is_zero: bool  # cached result of the is_zero property

    # Names starting with "__" are name-mangled in __slots__ as well, so the
    # "__hash" and "__is_zero" slots back the self.__hash / self.__is_zero
    # attributes used below.
    __slots__ = (
        "__weakref__",
        "_shape",
        "_data",
        "_readonly",
        "__hash",
        "__is_zero",
    )

    def __new__(
        cls, data: npt.ArrayLike, *, readonly: bool = False, copy: bool = False
    ) -> Self:
        """
        Creates a new tensor from the given binary data; the shape of the
        tensor is the shape of the data.

        The class :class:`bintensor` cannot be instantiated directly:
        instantiate one of its subclasses instead.

        If ``readonly=True``, the resulting tensor and its data are readonly.
        If ``copy=True``, a fresh copy of the given data is used.

        .. warning::

            The internal logic of the :class:`bintensor` class and its
            subclasses presumes that the given data will not be mutated
            externally to the :class:`bintensor` object after construction.
            If a fresh copy is needed, pass ``copy=True`` at construction.
            With ``readonly=True`` and ``copy=False``, a NumPy ``uint8``
            array passed as ``data`` is itself made non-writeable.

        :raises ValueError: if some data entry is not 0 or 1 (checked by
            :meth:`_validate_data` inside an ``assert``, so skipped under
            ``python -O``).

        :meta public:
        """
        assert cls is not bintensor, "Can't instantiate class 'bintensor'."
        if not isinstance(data, np.ndarray) or data.dtype != np.uint8:
            # NOTE(review): this cast is unchecked, e.g. 1.5 becomes 1, so
            # some non-binary inputs can slip past _validate_data below.
            data = np.array(data, dtype=np.uint8)
            # The conversion already produced a fresh array: no need to copy.
            copy = False
        assert cls._validate_data(data)
        if copy:
            data = data.copy()
        if readonly:
            data.flags.writeable = False
            data = data.view()
        instance = super().__new__(cls)
        instance._shape = data.shape
        instance._data = data
        instance._readonly = readonly
        return instance

    def __getnewargs_ex__(self) -> tuple[tuple[BinTensor], dict[str, Any]]:
        """
        Method for pickling: the tensor is rebuilt by calling ``__new__``
        with its data and readonly flag.
        """
        return (self._data,), {"readonly": self._readonly}

    # TODO: data compression on pickling
    # TODO: data decompression on unpickling
    # Move the __bytes__ and from_bytes logic from binvec to here,
    # reshape/linearise tensor data and include shape information in the bytes
    # as a header: shape length (uint8), shape (tuple of varints), data (bytes)
    # https://github.com/multiformats/unsigned-varint

    @property
    def shape(self) -> Shape:
        """
        The shape of this bit tensor.
        """
        return self._shape

    @property
    def data(self) -> BinTensor:
        """
        The underlying binary data.

        This is the tensor's own array, not a copy.
        """
        return self._data

    @property
    def readonly(self) -> bool:
        """
        Whether the tensor is readonly.
        """
        return self._readonly

    @readonly.setter
    def readonly(self, value: Literal[True]) -> None:
        """
        Makes the tensor readonly (only ``True`` may be assigned).

        The underlying array is marked non-writeable and replaced by a view
        of itself.
        """
        assert validate(value, Literal[True])
        if self._readonly:
            return
        data = self._data
        data.flags.writeable = False
        self._data = data.view()
        self._readonly = True

    @property
    def is_zero(self) -> bool:
        """
        Whether the tensor is the constant zero tensor.

        The result is cached, and the cache is discarded by
        :meth:`_postprocess_mutation` after in-place operations.
        """
        try:
            return self.__is_zero
        except AttributeError:
            self.__is_zero = (is_zero := bool(np.all(self._data == 0)))
            return is_zero

    def copy(self, *, readonly: bool = False) -> Self:
        """
        Returns a copy of this tensor.

        If ``readonly=True``, the resulting copy is readonly.
        """
        return type(self)(self._data.copy(), readonly=readonly)

    def _postprocess_mutation(self) -> None:
        """
        Method called when tensor data is potentially mutated.

        Refreshes the cached shape and drops the cached :attr:`is_zero`
        value. Can be overridden by subclasses to avoid stale cache values
        (overrides should call this implementation too).
        """
        assert (
            not self._readonly
        ), "You have mutated a readonly tensor, data integrity not guaranteed."
        self._shape = self._data.shape
        # The data may have changed, so a cached is_zero may now be wrong.
        try:
            del self.__is_zero
        except AttributeError:
            pass  # nothing was cached

    def _same_shape(self, other: bintensor, op: str) -> Literal[True]:
        """
        Enforces that this and the other given tensor have the same shape.

        The ``op`` argument refers to the operation being performed,
        and is used by the error message.

        :raises ShapeError: if the shapes differ.
        """
        if self._shape != other._shape:
            raise bintensor.ShapeError(
                f"unsupported operand shapes for {op}: "
                f"{self._shape} and {other._shape}"
            )
        return True

    def _same_type(self, other: bintensor, op: str) -> Literal[True]:
        """
        Enforces that the other tensor is an instance of this tensor's class.

        The ``op`` argument refers to the operation being performed,
        and is used by the error message.

        :raises TypeError: if ``other`` is not an instance of this class.
        """
        cls = type(self)
        if not isinstance(other, cls):
            st_str = cls.__name__
            ot_str = type(other).__name__
            raise TypeError(
                f"unsupported operand types for {op}: {st_str!r} and {ot_str!r}"
            )
        return True

    def __pos__(self) -> Self:
        """
        Returns the tensor unchanged.

        :meta public:
        """
        return self

    def __neg__(self) -> Self:
        """
        Componentwise mod 2 negation: returns the tensor unchanged.

        :meta public:
        """
        return self

    def __add__(self, other: Self | Integral) -> Self:
        """
        Componentwise mod 2 addition (bitwise XOR).

        An integral ``other`` is reduced mod 2 and added to every entry.

        :raises ShapeError: if the tensors have different shapes.

        :meta public:
        """
        cls = type(self)
        if isinstance(other, Integral):
            return cls(self._data ^ (int(other) % 2))
        if not isinstance(other, cls):
            return NotImplemented
        assert self._same_shape(other, "+")
        return cls(self._data ^ other._data)

    def __sub__(self, other: Self | Integral) -> Self:
        """
        Alias for :meth:`__add__`.

        :meta public:
        """
        return self + other

    def __mul__(self, other: Self | Integral) -> Self:
        """
        Componentwise mod 2 multiplication (bitwise AND).

        An integral ``other`` is reduced mod 2 and multiplies every entry.

        :raises ShapeError: if the tensors have different shapes.

        :meta public:
        """
        cls = type(self)
        if isinstance(other, Integral):
            return cls(self._data * (int(other) % 2))
        if not isinstance(other, cls):
            return NotImplemented
        assert self._same_shape(other, "*")
        return cls(self._data * other._data)

    def __iadd__(self, other: Self | Integral) -> Self:
        """
        Inplace componentwise mod 2 addition (bitwise XOR).

        :raises ReadonlyError: if this tensor is readonly.
        :raises TypeError: if ``other`` is neither integral nor of this class.
        :raises ShapeError: if the tensors have different shapes.

        :meta public:
        """
        if self._readonly:
            raise bintensor.ReadonlyError("Tensor is read-only.")
        if isinstance(other, Integral):
            self._data ^= int(other) % 2
        else:
            assert self._same_type(other, "+=")
            assert self._same_shape(other, "+=")
            self._data ^= other._data
        self._postprocess_mutation()
        return self

    def __isub__(self, other: Self | Integral) -> Self:
        """
        Alias for :meth:`__iadd__`.

        :meta public:
        """
        return self.__iadd__(other)

    def __imul__(self, other: Self | Integral) -> Self:
        """
        Inplace componentwise mod 2 multiplication (bitwise AND).

        :raises ReadonlyError: if this tensor is readonly.
        :raises TypeError: if ``other`` is neither integral nor of this class.
        :raises ShapeError: if the tensors have different shapes.

        :meta public:
        """
        if self._readonly:
            raise bintensor.ReadonlyError("Tensor is read-only.")
        if isinstance(other, Integral):
            self._data *= int(other) % 2
        else:
            assert self._same_type(other, "*=")
            assert self._same_shape(other, "*=")
            self._data *= other._data
        self._postprocess_mutation()
        return self

    def __eq__(self, other: Any) -> bool:
        """
        Tensors are equal when they have exactly the same type, the same
        shape and the same entries. Comparison with objects of any other
        type is delegated (``NotImplemented``).
        """
        if type(self) != type(other):
            return NotImplemented
        return self._shape == other._shape and bool(
            np.all(self._data == other._data)
        )

    def __repr__(self) -> str:
        """
        NumPy-style representation, with the class name in place of
        ``array`` and without the implied ``dtype=uint8``.
        """
        np_prefix = "array("
        s = repr(self._data)
        assert s.startswith(np_prefix), s
        if s.endswith(", dtype=uint8)"):
            body = s[len(np_prefix) : -len(", dtype=uint8)")]
        else:
            assert s.endswith("dtype=uint8)"), s
            # NumPy moved the dtype onto its own line: drop it together with
            # the preceding line break, padding and comma, so that no
            # trailing comma or dangling ")" line is left behind.
            body = s[len(np_prefix) : -len("dtype=uint8)")].rstrip()
            body = body.removesuffix(",")
        cls_name = type(self).__name__
        # NumPy indents continuation lines relative to "array(": shift them
        # to be relative to "<cls_name>(" instead, which preserves the nested
        # layout (and keeps blank separator lines empty).
        np_pad = " " * len(np_prefix)
        pad = " " * (len(cls_name) + 1)
        first, *rest = body.split("\n")
        lines = [cls_name + "(" + first]
        lines.extend(
            pad + line.removeprefix(np_pad) if line else line for line in rest
        )
        return "\n".join(lines) + ")"

    def __hash__(self) -> int:
        """
        Hash of a readonly tensor, computed from its type, shape and data,
        then cached.

        :raises TypeError: if the tensor is not readonly.
        """
        try:
            return self.__hash
        except AttributeError:
            if not self._readonly:
                raise TypeError(
                    "Only readonly tensors can be hashed."
                ) from None
            h = hash((type(self), self._shape, bytes(self._data)))
            self.__hash = h
            return h

    @overload
    def __array__(
        self, dtype: None = None, copy: bool | None = None
    ) -> npt.NDArray[np.uint8]: ...

    @overload
    def __array__(
        self,
        dtype: np.dtype[_ScalarType] | type[_ScalarType],
        copy: bool | None = None,
    ) -> npt.NDArray[_ScalarType]: ...

    def __array__(
        self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
    ) -> Any:
        """
        NumPy array protocol, used by :func:`numpy.asarray` and friends.

        ``dtype`` is the requested dtype (default: the data's own ``uint8``).
        ``copy`` follows the NumPy 2 protocol: ``None`` copies only if
        needed, ``True`` always copies, ``False`` asks for no copy. NumPy 2
        passes it by keyword, so it must not be positional-only.

        Without a copy, the returned array is the tensor's own data.
        """
        if copy is None:
            # np.array(..., copy=None) is only understood by NumPy >= 2,
            # whereas np.asarray gives "copy only if needed" everywhere.
            return np.asarray(self._data, dtype=dtype)
        return np.array(self._data, dtype=dtype, copy=copy)