Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,18 @@
ne,
neg,
pow,
quantile,
relu,
rms_norm,
rot90,
rotary_position_embedding,
rsqrt,
scaled_dot_product_attention,
select_copy,
sgn,
sigmoid,
sign,
signbit,
silu,
sin,
softmax,
Expand Down Expand Up @@ -71,12 +77,18 @@
"ne",
"neg",
"pow",
"quantile",
"relu",
"rms_norm",
"rot90",
"rotary_position_embedding",
"rsqrt",
"scaled_dot_product_attention",
"select_copy",
"sgn",
"sigmoid",
"sign",
"signbit",
"silu",
"sin",
"softmax",
Expand Down
122 changes: 122 additions & 0 deletions src/ntops/kernels/quantile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, q, dim_size, output, dim, block_size=None):
    """Arrange tensors so each program reduces one full slice along `dim`.

    The reduction dimension of `input` (and dim 0 of `output`) is moved last
    and tiled whole; `q` is broadcast to match the arranged output.
    NOTE(review): `block_size` is accepted but not used by these tilings —
    confirm whether that is intentional.
    """

    def _tile_along(tensor, target_dim):
        # Move `target_dim` last, tile it as one block, and squeeze the
        # size-1 leading block axes away.
        rank = tensor.ndim

        if target_dim < 0:
            target_dim += rank

        other_dims = tuple(axis for axis in range(rank) if axis != target_dim)

        tiled = tensor.permute(other_dims + (target_dim,))
        tiled = tiled.tile((1,) * len(other_dims) + (-1,))
        tiled.dtype = tiled.dtype.squeeze(tuple(range(len(other_dims))))

        return tiled

    input_arranged = _tile_along(input, dim)
    output_arranged = _tile_along(output, 0)

    # Broadcast the 1-D quantile tensor across every output block axis.
    q_arranged = q.tile((-1,)).squeeze(0)

    for _ in range(output_arranged.ndim):
        q_arranged = q_arranged.unsqueeze(0)

    q_arranged = q_arranged.expand(output_arranged.shape)

    return input_arranged, q_arranged, dim_size, output_arranged


def linear_application(input, q, dim_size, output):
    """Linear interpolation between the two order statistics around q*(n-1)."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    i = ntl.cast(ntl.floor(pos), ntl.int32)
    j = ntl.cast(ntl.ceil(pos), ntl.int32)
    frac = pos - i

    # `sorted` shadowed the builtin; use a descriptive name instead.
    sorted_input = ntl.sort(input)
    lower_value = ntl.gather(sorted_input, i, 0)
    higher_value = ntl.gather(sorted_input, j, 0)

    output = lower_value + frac * (higher_value - lower_value)  # noqa: F841


def lower_application(input, q, dim_size, output):
    """Take the order statistic at floor(q * (dim_size - 1))."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    i = ntl.cast(ntl.floor(pos), ntl.int32)

    # `sorted` shadowed the builtin; use a descriptive name instead.
    sorted_input = ntl.sort(input)
    lower_value = ntl.gather(sorted_input, i, 0)

    output = lower_value  # noqa: F841


def higher_application(input, q, dim_size, output):
    """Take the order statistic at ceil(q * (dim_size - 1))."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    j = ntl.cast(ntl.ceil(pos), ntl.int32)

    # `sorted` shadowed the builtin; use a descriptive name instead.
    sorted_input = ntl.sort(input)
    higher_value = ntl.gather(sorted_input, j, 0)

    output = higher_value  # noqa: F841


def nearest_application(input, q, dim_size, output):
    """Pick the order statistic nearest to q * (dim_size - 1)."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)

    # Float-to-int conversion truncates towards zero, so round-to-nearest,
    # ties-to-even ("rtne") has to be implemented by hand.
    i = ntl.cast(ntl.floor(pos), ntl.int32)
    frac = ntl.cast(pos - i, ntl.float32)
    i = ntl.where(frac > 0.5, ntl.minimum(i + 1, dim_size - 1), i)
    # On an exact tie, round up only when the lower index is odd (ties to even).
    i = ntl.where((frac == 0.5) & (i % 2 == 1), ntl.minimum(i + 1, dim_size - 1), i)

    # `sorted` shadowed the builtin; use a descriptive name instead.
    sorted_input = ntl.sort(input)
    output = ntl.gather(sorted_input, i, 0)  # noqa: F841


def midpoint_application(input, q, dim_size, output):
    """Average the two order statistics surrounding q * (dim_size - 1)."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    i = ntl.cast(ntl.floor(pos), ntl.int32)
    j = ntl.cast(ntl.ceil(pos), ntl.int32)

    # `sorted` shadowed the builtin; use a descriptive name instead.
    sorted_input = ntl.sort(input)
    lower_value = ntl.gather(sorted_input, i, 0)
    higher_value = ntl.gather(sorted_input, j, 0)

    output = (higher_value + lower_value) / 2  # noqa: F841


def premake(in_ndim, out_ndim, dim, interpolation, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for quantile.

    Raises:
        ValueError: If `interpolation` is not one of the supported modes.
    """
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(in_ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(0),
        Tensor(out_ndim, dtype=dtype, shape_options={"constexpr": True}),
    )

    # Dispatch table instead of an if/elif chain.
    applications = {
        "linear": linear_application,
        "lower": lower_application,
        "higher": higher_application,
        "nearest": nearest_application,
        "midpoint": midpoint_application,
    }

    application = applications.get(interpolation)

    if application is None:
        raise ValueError(f"Unsupported interpolation method: {interpolation}")

    return arrangement_, application, tensors
88 changes: 88 additions & 0 deletions src/ntops/kernels/rot90.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, output, k, dims, block_size=None):
    """Arrange input/output for a rot90 kernel, specialized on k mod 4.

    k % 4 == 0 is a plain copy, 1 and 3 are a single flip along one rotated
    axis (with the axes of one side swapped), and 2 is a double flip.
    """
    if block_size is None:
        block_size = ninetoothed.block_size()

    rank = input.ndim
    dims = tuple(d + rank if d < 0 else d for d in dims)
    kept_dims = tuple(axis for axis in range(rank) if axis not in dims)

    def _flat_tiles(tensor):
        # Identity rotation: a 1-D block copy is sufficient.
        return tensor.flatten().tile((block_size,))

    def _row_tiles(tensor, ordered_dims):
        # Single flip: each program handles one row along the last rotated axis.
        arranged = tensor.permute(kept_dims + ordered_dims)
        arranged = arranged.flatten(end_dim=-1).tile((1, -1))
        arranged.dtype = arranged.dtype.squeeze(0)
        return arranged

    def _plane_tiles(tensor, ordered_dims):
        # Double flip: each program handles one 2-D plane.
        arranged = tensor.permute(kept_dims + ordered_dims)
        if rank == 2:
            arranged = arranged.unsqueeze(0)
        arranged = arranged.flatten(end_dim=-2).tile((1, -1, -1))
        arranged.dtype = arranged.dtype.squeeze(0)
        return arranged

    rotation = k % 4

    if rotation == 0:
        return _flat_tiles(input), _flat_tiles(output)
    if rotation == 1:
        return _row_tiles(input, dims), _row_tiles(output, tuple(reversed(dims)))
    if rotation == 3:
        return _row_tiles(input, tuple(reversed(dims))), _row_tiles(output, dims)
    # rotation == 2
    return _plane_tiles(input, dims), _plane_tiles(output, dims)


def application_0(input, output):
    # k % 4 == 0: the rotation is the identity, so just copy the tile.
    output = input  # noqa: F841


def application_1_or_3(input, output):
    """Reverse the tile along axis 0; a length-1 axis needs no flip."""
    if input.shape[0] != 1:
        output = ntl.flip(input, 0)  # noqa: F841
    else:
        output = input  # noqa: F841


def application_2(input, output):
    # k % 4 == 2: a 180-degree rotation is a flip along both plane axes.
    output = ntl.flip(ntl.flip(input, 0), 1)  # noqa: F841


def premake(ndim, k, dims, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for rot90."""
    arrangement_ = functools.partial(arrangement, k=k, dims=dims, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
    )

    # k % 4 in {1, 3} shares one application; 0 and 2 are special-cased.
    special_cases = {0: application_0, 2: application_2}
    application = special_cases.get(k % 4, application_1_or_3)

    return arrangement_, application, tensors
50 changes: 50 additions & 0 deletions src/ntops/kernels/select_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, index, output, dim, block_size=None):
    """Arrange tensors so each program copies one element of the selected slice.

    NOTE(review): `block_size` is resolved below but not used by the tilings —
    confirm whether that is intentional.
    """
    if block_size is None:
        block_size = ninetoothed.block_size()

    # Output becomes a flat sequence of single-element tiles.
    output = output.unsqueeze(0) if output.ndim < 1 else output.flatten()

    output_arranged = output.tile((1,))
    output_arranged.dtype = output_arranged.dtype.squeeze(0)

    # Move the selection dimension of the input last, then tile one row
    # (all candidates along `dim`) per program.
    if input.ndim < 2:
        input = input.unsqueeze(0)
    else:
        if dim < 0:
            dim += input.ndim

        rest = tuple(axis for axis in range(input.ndim) if axis != dim)
        input = input.permute(rest + (dim,))

    input_arranged = input.flatten(end_dim=-1).tile((1, -1))
    input_arranged.dtype = input_arranged.dtype.squeeze(0)

    return input_arranged, index, output_arranged


def application(input, index, output):
    """Copy the element at position `index` of the tiled row into the output."""
    output = input[ntl.cast(index, ntl.int32)]  # noqa: F841


def premake(in_ndim, out_ndim, dim, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for select-copy."""
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    input_tensor = Tensor(in_ndim, dtype=dtype, shape_options={"constexpr": True})
    index_tensor = Tensor(0, dtype=ninetoothed.int32)
    output_tensor = Tensor(out_ndim, dtype=dtype)

    return arrangement_, application, (input_tensor, index_tensor, output_tensor)
39 changes: 39 additions & 0 deletions src/ntops/kernels/sgn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, output, block_size=None):
    """Tile both tensors into (block_size, full-row) blocks for sgn."""
    if block_size is None:
        block_size = ninetoothed.block_size()

    def _blocked(tensor):
        tiled = tensor.flatten(end_dim=-1)
        tiled = tiled.tile((block_size, 1)).tile((1, -1))
        tiled.dtype = tiled.dtype.squeeze(0)
        return tiled

    return _blocked(input), _blocked(output)


def application(input, output):
    # Axis 0 appears to hold the components of a complex value: the magnitude
    # below uses only input[0] and input[1] — presumably shape[0] == 2
    # (real, imaginary). TODO(review): confirm against the callers.
    denominators = ntl.sqrt(input[0] * input[0] + input[1] * input[1])
    # Guard against division by zero: sgn(0) is 0, and 0 / 1 == 0.
    denominators = ntl.where(denominators == 0.0, 1.0, denominators)

    # Normalize each component by the shared magnitude.
    for i in range(input.shape[0]):
        output[i] = input[i] / denominators  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for sgn."""
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    io_pair = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

    return arrangement_, application, io_pair
18 changes: 18 additions & 0 deletions src/ntops/kernels/sign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
    """Element-wise sign: -1, 0, or +1 matching the sign of each input."""
    # Equivalent to the positive-first nesting: the two comparisons are
    # disjoint, and anything that is neither (including NaN) maps to 0.
    output = ntl.where(input < 0, -1, ntl.where(input > 0, 1, 0))  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for sign."""
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    input_tensor = Tensor(ndim, dtype=dtype)
    output_tensor = Tensor(ndim, dtype=dtype)

    return arrangement_, application, (input_tensor, output_tensor)
26 changes: 26 additions & 0 deletions src/ntops/kernels/signbit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
    # Extract the IEEE-754 sign bit: bit-cast to the unsigned integer of the
    # same width, then shift the most significant bit down to bit 0.
    if input.dtype is ntl.float16:
        i_unint = ntl.cast(input, ntl.uint16, bitcast=True)
        output = (i_unint >> 15) & 0x1  # noqa: F841
    elif input.dtype is ntl.float32:
        i_unint = ntl.cast(input, ntl.uint32, bitcast=True)
        output = (i_unint >> 31) & 0x1  # noqa: F841
    elif input.dtype is ntl.float64:
        i_unint = ntl.cast(input, ntl.uint64, bitcast=True)
        output = (i_unint >> 63) & 0x1  # noqa: F841
    # NOTE(review): any other dtype (e.g. bfloat16, integers) falls through
    # with `output` never assigned — confirm whether that is intentional.


def premake(ndim, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for signbit."""
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    io_tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

    return arrangement_, application, io_tensors
Loading