Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,54 @@
"endColumn": 29,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 15,
"endColumn": 38,
"lineCount": 1
}
}
],
"./arraycontext/fake_numpy.py": [
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v6
- name: "Main Script"
run: |
EXTRA_INSTALL="pytest types-colorama types-Pygments"
EXTRA_INSTALL="pytest types-colorama types-Pygments scipy-stubs"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0

Expand Down
4 changes: 4 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
from .context import (
ArrayContext,
ArrayContextFactory,
CSRMatrix,
SparseMatrix,
tag_axes,
)
from .impl.jax import EagerJAXArrayContext
Expand Down Expand Up @@ -129,6 +131,7 @@
"ArrayOrScalarT",
"ArrayT",
"BcastUntilActxArray",
"CSRMatrix",
"CommonSubexpressionTag",
"ContainerOrScalarT",
"EagerJAXArrayContext",
Expand All @@ -144,6 +147,7 @@
"ScalarLike",
"SerializationKey",
"SerializedContainer",
"SparseMatrix",
"dataclass_array_container",
"deserialize_container",
"flat_size_and_dtype",
Expand Down
200 changes: 199 additions & 1 deletion arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@

.. autoclass:: ArrayContext

.. autoclass:: SparseMatrix
.. autoclass:: CSRMatrix

.. autofunction:: tag_axes

.. class:: P
Expand Down Expand Up @@ -114,13 +117,15 @@
"""


import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Callable, Hashable, Mapping
from typing import (
TYPE_CHECKING,
Any,
ParamSpec,
TypeAlias,
cast,
overload,
)
from warnings import warn
Expand All @@ -129,21 +134,27 @@

from pytools import memoize_method

from arraycontext.container.traversal import (
rec_map_container,
)


if TYPE_CHECKING:
import numpy as np
from numpy.typing import DTypeLike

import loopy
from pytools.tag import ToTagSetConvertible
from pytools.tag import Tag, ToTagSetConvertible

from .fake_numpy import BaseFakeNumpyNamespace
from .typing import (
Array,
ArrayContainerT,
ArrayOrArithContainerOrScalarT,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ArrayOrScalar,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
ScalarLike,
Expand All @@ -152,6 +163,26 @@

P = ParamSpec("P")

_EMPTY_TAG_SET: frozenset[Tag] = frozenset()


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class SparseMatrix(ABC):
    """Abstract base class for sparse matrices bound to an :class:`ArrayContext`.

    Instances support the ``@`` operator against arrays and array containers;
    the product is delegated to :meth:`ArrayContext.sparse_matmul` of the
    context the matrix was created with.
    """

    # (nrows, ncols) of the (two-dimensional) matrix
    shape: tuple[int, int]
    # tags applying to the matrix as a whole
    tags: ToTagSetConvertible = dataclasses.field(kw_only=True)
    # per-axis tags, one entry per axis of *shape*
    axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True)
    # the array context that created this matrix and carries out products
    _actx: ArrayContext = dataclasses.field(kw_only=True)

    def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
        # Delegate to the array context, which dispatches on the concrete
        # sparse matrix type (see ArrayContext.sparse_matmul).
        return self._actx.sparse_matmul(self, other)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class CSRMatrix(SparseMatrix):
    """A sparse matrix in compressed sparse row (CSR) format.

    Typically created via :meth:`ArrayContext.make_csr_matrix`, which
    documents the meaning of the fields in detail.
    """

    # one-dimensional array of the nonzero entries, grouped by row
    elem_values: Array
    # one-dimensional array: column index of each entry in *elem_values*
    elem_col_indices: Array
    # one-dimensional array of per-row start offsets into
    # *elem_values*/*elem_col_indices*
    row_starts: Array


# {{{ ArrayContext

Expand All @@ -169,6 +200,8 @@ class ArrayContext(ABC):
.. automethod:: to_numpy
.. automethod:: call_loopy
.. automethod:: einsum
.. automethod:: make_csr_matrix
.. automethod:: sparse_matmul
.. attribute:: np

Provides access to a namespace that serves as a work-alike to
Expand Down Expand Up @@ -421,6 +454,171 @@ def einsum(self,
)["out"]
return self.tag(tagged, out_ary)

    def make_csr_matrix(
            self,
            shape: tuple[int, int],
            elem_values: Array,
            elem_col_indices: Array,
            row_starts: Array,
            *,
            tags: ToTagSetConvertible = _EMPTY_TAG_SET,
            axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
        """Return a sparse matrix in compressed sparse row (CSR) format, to be used
        with :meth:`sparse_matmul`.

        :arg shape: the (two-dimensional) shape of the matrix
        :arg elem_values: a one-dimensional array containing the values of all of the
            nonzero entries of the matrix, grouped by row.
        :arg elem_col_indices: a one-dimensional array containing the column index
            values corresponding to each entry in *elem_values*.
        :arg row_starts: a one-dimensional array of length `nrows+1`, where each entry
            gives the starting index in *elem_values* and *elem_col_indices* for the
            given row, with the last entry being equal to the number of stored
            (nonzero) entries, i.e. ``len(elem_values)``.
        :arg tags: tags to apply to the matrix as a whole.
        :arg axes: per-axis tags; defaults to no tags on either axis.
        """
        # No tags on either of the two matrix axes unless the caller says so.
        if axes is None:
            axes = (frozenset(), frozenset())

        return CSRMatrix(
            shape, elem_values, elem_col_indices, row_starts,
            tags=tags, axes=axes,
            _actx=self)

    @memoize_method
    def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
        """Build (and memoize, per *out_ndim*) a :mod:`loopy` program computing
        ``out = csr_matrix @ array``, where *array* and *out* have *out_ndim*
        axes and the product contracts over the first axis of *array*.

        The generated kernel iterates over output rows (``irow``) plus any
        trailing axes of *array*; for each output entry it sums
        ``elem_values[iel] * array[elem_col_indices[iel], ...]`` over the
        elements ``iel`` belonging to that row, as delimited by
        ``row_starts[irow]`` and ``row_starts[irow + 1]``.
        """
        import numpy as np

        import loopy as lp

        # Inames/shape names for the output axes beyond the row axis, i.e.
        # the trailing axes of *array* that are carried through unchanged.
        out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim))
        out_inames = ("irow", *out_extra_inames)
        out_inames_set = frozenset(out_inames)

        out_extra_shape_comp_names = tuple(f"n{n}" for n in range(1, out_ndim))
        out_shape_comp_names = ("nrows", *out_extra_shape_comp_names)

        # ISL-syntax iteration domains: one rectangular domain over all output
        # axes, and a data-dependent one for the per-row element range.
        domains: list[str] = []
        domains.append(
            "{ [" + ",".join(out_inames) + "] : "
            + " and ".join(
                f"0 <= {iname} < {shape_comp_name}"
                for iname, shape_comp_name in zip(
                    out_inames, out_shape_comp_names, strict=True))
            + " }")
        domains.append(
            "{ [iel] : iel_lbound <= iel < iel_ubound }")

        # Scalar temporaries holding the per-row bounds read from row_starts;
        # they parameterize the data-dependent [iel] domain above.
        temporary_variables: Mapping[str, lp.TemporaryVariable] = {
            "iel_lbound": lp.TemporaryVariable(
                "iel_lbound",
                shape=(),
                address_space=lp.AddressSpace.GLOBAL,
            ),
            "iel_ubound": lp.TemporaryVariable(
                "iel_ubound",
                shape=(),
                address_space=lp.AddressSpace.GLOBAL,
            )}

        from loopy.kernel.instruction import make_assignment
        from pymbolic import var
        # insn0/insn1 load the element range for the current row; insn2
        # performs the reduction over that range. insn2 depends on both so
        # the bounds are in place before the sum runs.
        instructions: list[lp.Assignment | lp.CallInstruction] = [
            make_assignment(
                (var("iel_lbound"),),
                var("row_starts")[var("irow")],
                id="insn0",
                within_inames=out_inames_set),
            make_assignment(
                (var("iel_ubound"),),
                var("row_starts")[var("irow") + 1],
                id="insn1",
                within_inames=out_inames_set),
            make_assignment(
                (var("out")[tuple(var(iname) for iname in out_inames)],),
                lp.Reduction(
                    "sum",
                    (var("iel"),),
                    var("elem_values")[var("iel"),]
                    * var("array")[(
                        var("elem_col_indices")[var("iel"),],
                        *(var(iname) for iname in out_extra_inames))]),
                id="insn2",
                within_inames=out_inames_set,
                depends_on=frozenset({"insn0", "insn1"}))]

        from loopy.version import MOST_RECENT_LANGUAGE_VERSION

        from .loopy import _DEFAULT_LOOPY_OPTIONS

        knl = lp.make_kernel(
            domains=domains,
            instructions=instructions,
            temporary_variables=temporary_variables,
            kernel_data=[
                lp.ValueArg("nrows", is_input=True),
                lp.ValueArg("ncols", is_input=True),
                lp.ValueArg("nels", is_input=True),
                *(
                    lp.ValueArg(shape_comp_name, is_input=True)
                    for shape_comp_name in out_extra_shape_comp_names),
                lp.GlobalArg("elem_values", shape=(var("nels"),), is_input=True),
                lp.GlobalArg("elem_col_indices", shape=(var("nels"),), is_input=True),
                lp.GlobalArg("row_starts", shape=lp.auto, is_input=True),
                lp.GlobalArg(
                    "array",
                    shape=(
                        var("ncols"),
                        *(
                            var(shape_comp_name)
                            for shape_comp_name in out_extra_shape_comp_names),),
                    is_input=True),
                lp.GlobalArg(
                    "out",
                    shape=(
                        var("nrows"),
                        *(
                            var(shape_comp_name)
                            for shape_comp_name in out_extra_shape_comp_names),),
                    is_input=False),
                # NOTE: the trailing "..." asks loopy to infer any remaining
                # kernel arguments (e.g. the bound temporaries) automatically.
                ...],
            name="csr_matmul_kernel",
            lang_version=MOST_RECENT_LANGUAGE_VERSION,
            options=_DEFAULT_LOOPY_OPTIONS,
            default_order=lp.auto,
            default_offset=lp.auto,
        )

        idx_dtype = knl.default_entrypoint.index_dtype

        # Pin down dtypes: sizes and index arrays use the kernel's index
        # dtype; value arrays are float64.
        return lp.add_and_infer_dtypes(
            knl,
            {
                ",".join([
                    "ncols", "nrows", "nels",
                    *out_extra_shape_comp_names]): idx_dtype,
                "elem_values,array,out": np.float64,
                "elem_col_indices,row_starts": idx_dtype})

def sparse_matmul(
self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
"""Multiply a sparse matrix by an array.

:arg x1: the sparse matrix.
:arg x2: the array.
"""
if isinstance(x1, CSRMatrix):
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
assert self.is_array_type(ary)
prg = self._get_csr_matmul_prg(len(ary.shape))
return self.call_loopy(
prg, elem_values=x1.elem_values,
elem_col_indices=x1.elem_col_indices,
row_starts=x1.row_starts, array=ary)["out"]

return cast("ArrayOrContainer", rec_map_container(_matmul, x2))

else:
raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'")

@abstractmethod
def clone(self) -> Self:
"""If possible, return a version of *self* that is semantically
Expand Down
Loading
Loading