Source code for q1ss.binalg.binvec

"""
Binary vectors.
"""

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.

# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations
from collections.abc import Iterable, Iterator, Sequence
from itertools import product
from numbers import Integral
from typing import TYPE_CHECKING, Any, Literal, cast, final, overload
from typing_extensions import Self
import numpy as np
import numpy.typing as npt
from typing_validation import validate
from .vectorized import (
    BinVec,
    bits_from_bytes,
    bytes_from_bits,
)
from .base import bintensor

if TYPE_CHECKING:
    from .binmat import binmat

Bit = Literal[0, 1]
"""
Type alias for a bit value.
"""


[docs] @final class binvec(bintensor): r""" A mutable binary vector. """
[docs] @staticmethod def validate_dim(dim: int, *, positive: bool = True) -> Literal[True]: """ Validate a :class:`binvec` dimension. Raises :obj:`TypeError` or :obj:`ValueError` for invalid dimensions, returns :obj:`True` otherwise. """ validate(dim, int) qual = "positive" if positive else "non-negative" if dim < 0 or (positive and dim == 0): raise ValueError(f"Dimension must be {qual}, found {dim}.") return True
[docs] @staticmethod def validate_bitstr(bits: str) -> Literal[True]: """ Validate a binary string. Raises :obj:`TypeError` or :obj:`ValueError` for invalid strings, returns :obj:`True` otherwise. """ validate(bits, str) if not all(b in "01" for b in bits): raise ValueError("Characters in bitstring must be '0' or '1'.") return True
[docs] @staticmethod def zeros(dim: int, *, readonly: bool = False) -> binvec: """ Constructs a zero binary vector with given dimension. """ assert binvec.validate_dim(dim, positive=False) return binvec(np.zeros(dim, dtype=np.uint8), readonly=readonly)
[docs] @staticmethod def random( dim: int, *, rng: np.random.Generator | int | None = None, readonly: bool = False, ) -> binvec: """ Random binary vector with given dimension. """ assert validate(dim, int) if not isinstance(rng, np.random.Generator): rng = np.random.default_rng(rng) bits = rng.integers(0, 2, (dim,), dtype=np.uint8) return binvec(bits, readonly=readonly)
[docs] @staticmethod def el(idx: int, dim: int, *, readonly: bool = False) -> binvec: """ Returns the canonical basis vector with given index in given dimension. """ assert validate(idx, int) assert binvec.validate_dim(dim) idx %= dim data = np.zeros(dim, dtype=np.uint8) data[idx] = 1 return binvec(data, readonly=readonly)
[docs] @staticmethod def iter_std_basis(dim: int, *, readonly: bool = False) -> Iterator[binvec]: """ Iterates through all standard basis binary vectors with given dimension. """ assert binvec.validate_dim(dim, positive=False) for idx in range(dim): data = np.zeros(dim, dtype=np.uint8) data[idx] = 1 yield binvec(data, readonly=readonly)
[docs] @staticmethod def iter_all(dim: int, *, readonly: bool = False) -> Iterator[binvec]: """ Iterates through all binary vectors with given dimension. """ assert binvec.validate_dim(dim, positive=False) if dim == 0: yield binvec.zeros(0) return for data in product([0, 1], repeat=dim): yield binvec(data, readonly=readonly)
[docs] @staticmethod def from_bool(bits: Iterable[Any], *, readonly: bool = False) -> binvec: """ Constructs a binary vector from an interable of boolean values. """ data = np.fromiter((1 if b else 0 for b in bits), dtype=np.uint8) return binvec(data, readonly=readonly)
[docs] @staticmethod def from_str(bits: str, *, readonly: bool = False) -> binvec: """ Constructs a binary vector from a string with chars ``'0'`` and ``'1'``. """ assert validate(bits, str) assert binvec.validate_bitstr(bits) data = np.array([int(b, 2) for b in bits], dtype=np.uint8) return binvec(data, readonly=readonly)
[docs] @staticmethod def from_int(bits: int, n: int) -> binvec: """ Constructs a binary vector from an integer. """ assert validate(bits, int) assert validate(n, int) bits %= 2**n if n == 0: return binvec([]) data = np.array( [int(b) for b in bin(bits)[2:].zfill(n)], dtype=np.uint8 ) return binvec(data)
[docs] @staticmethod def hstack(vecs: Sequence[binvec], *, readonly: bool = False) -> binvec: """ Stacks the given vectors horizontally. """ assert validate(vecs, Sequence[binvec]) return binvec(np.hstack([v._data for v in vecs]), readonly=readonly)
[docs] @staticmethod def from_bytes( b: bytes, num_bits: int | None = None, *, readonly: bool = False ) -> binvec: """ Converts bytes to a binary vector containing the corresponding bits. The binary vector has length ``8*len(b)`` by default, containing all bits, but length can be truncated by specifying a desired ``num_bits`` between ``len(b)-7`` and ``len(b)`` (both inclusive). If a length is specified, the bits ignored at the end must all be zero. """ if num_bits is None: num_bits = len(b) * 8 if len(b) * 8 - num_bits not in range(8): raise ValueError(f"Expected n in range({len(b)*8-7}, {len(b)*8+1})") b_array = np.fromiter(b, dtype=np.uint8) if num_bits % 8 != 0 and b_array[-1] & ( 2 ** (k := 8 - num_bits % 8) - 1 ): raise ValueError( f"Expected last {k} bits to be zero, " f"found {b_array[-1]&(2**k-1):0>{k}b}" ) return binvec(bits_from_bytes(b_array, num_bits), readonly=readonly)
@classmethod def _validate_data(cls, data: BinVec) -> Literal[True]: super()._validate_data(data) if len(data.shape) != 1: raise ValueError( f"Expected data to be 1D array, found shape {data.shape}." ) return True
[docs] def __new__( cls, data: npt.ArrayLike, *, readonly: bool = False, copy: bool = False ) -> Self: """ Creates a new vector from binary data. If ``readonly=True``, the resulting tensor and its data are readonly. If ``copy=True``, a fresh copy of the given data is used. .. warning:: The internal logic of :class:`binvec` presumes that the given data will not be mutated externally to the :class:`binvec` object after construction. If a fresh copy is needed, pass ``copy=True`` at construction. :meta public: """ if not isinstance(data, np.ndarray) or data.dtype != np.uint8: data = np.array(data, dtype=np.uint8) return super().__new__( cls, cast(BinVec, data), readonly=readonly, copy=copy )
@property def bin(self) -> str: """ Binary string representation of this binary vector. """ return "".join(str(b) for b in self)
[docs] def __bytes__(self) -> bytes: """ Converts this binary vector to bytes. :meta public: """ return bytes(bytes_from_bits(self._data))
@overload def __getitem__(self, idx: int) -> Bit: ... @overload def __getitem__(self, idx: slice | list[int]) -> binvec: ...
[docs] def __getitem__(self, idx: int | slice | list[int]) -> Bit | binvec: """ If the index is an integer, returns the corresponding entry of the vec. If the index is a slice or a list/array of integers, returns the vec containing the selected entries. :meta public: """ if isinstance(idx, int): return cast(Bit, int(self._data[idx])) assert validate(idx, slice | list[int]) return binvec(self._data[idx])
@overload def __setitem__(self, idx: int, value: int) -> None: ... @overload def __setitem__( self, idx: slice | list[int], value: int | binvec ) -> None: ...
[docs] def __setitem__( self, idx: int | slice | list[int], value: int | binvec ) -> None: """ Sets a single value, or a slice/selection of values. :meta public: """ if self._readonly: raise binvec.ReadonlyError("Tensor is read-only.") if isinstance(idx, int): assert validate(value, Integral) self._data[idx] = int(cast(Integral, value)) % 2 else: assert validate(idx, slice | list[int]) if isinstance(value, binvec): self._data[idx] = value._data else: self._data[idx] = int(value) % 2 self._postprocess_mutation()
[docs] def __or__(self, other: binvec) -> binvec: """ Horizontal stacking of two vectors. :meta public: """ if not isinstance(other, binvec): return NotImplemented return binvec.hstack([self, other])
@overload def __matmul__(self, other: binvec) -> Bit: ... @overload def __matmul__(self, other: binmat) -> binvec: ...
[docs] def __matmul__(self, other: binvec | binmat) -> Bit | binvec: r""" Binary vector-vector inner product or vector-matrix multiplication. :raises ShapeError: if the intermediate dimensions don't match. :meta public: """ from .binmat import binmat if not isinstance(other, (binvec, binmat)): return NotImplemented assert self.__has_compatible_matmul_shape(other) res = (self._data @ other._data) % 2 if isinstance(other, binmat): return binvec(cast(BinVec, res)) return cast(Bit, int(res))
[docs] def __imatmul__(self, other: binmat) -> binvec: # type: ignore[misc] r""" Inplace binary vector-matrix multiplication. :raises ShapeError: if the intermediate dimensions don't match. :meta public: """ if self._readonly: raise bintensor.ReadonlyError("Tensor is read-only.") from .binmat import binmat if not isinstance(other, binmat): return NotImplemented assert self.__has_compatible_matmul_shape(other) self._data = (self._data @ other._data) % 2 self._postprocess_mutation() return self
def __len__(self) -> int: """ The dimension of this binary vector (i.e. the number of bits). :meta public: """ return self._shape[0] def __iter__(self) -> Iterator[Bit]: """ Iterates over the bits in this binary vector. :meta public: """ for b in self._data: yield cast(Bit, int(b))
[docs] def __int__(self) -> int: """ Converts the binary vector to an integer. :meta public: """ n = len(self) if n == 0: return 0 return int( np.sum(self._data * 2 ** (n - 1 - np.arange(n, dtype=np.uint64))) )
def __has_compatible_matmul_shape( self, other: binvec | binmat ) -> Literal[True]: if self.shape[-1] != other._shape[0]: ss, os = self._shape, other._shape raise bintensor.ShapeError( f"unsupported operand shapes for @: {ss} and {os}" ) return True