GeneralConv

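General N-dimensional convolution over a multi-channel input, with configurable stride, padding, kernel dilation, input dilation, feature grouping, and optional kernel flipping.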

Abstract Signature:

GeneralConv(input: Tensor, weight: Tensor, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False)
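To make the semantics of each parameter concrete, here is a naive 1-D NumPy reference sketch. It is not the library's implementation. It assumes channels-first layouts, input (N, C_in, W) and weight (C_out, C_in // groups, K), which the abstract signature does not fix, and it accepts only a single int per parameter for brevity. The name `general_conv1d_reference` is hypothetical.

```python
import numpy as np


def general_conv1d_reference(x, w, stride=1, padding=0, kernel_dilation=1,
                             input_dilation=1, groups=1, flip=False):
    """Naive 1-D reference for GeneralConv (hypothetical helper).

    Assumed layouts (not fixed by the abstract signature):
      x: (N, C_in, W)                 channels-first input
      w: (C_out, C_in // groups, K)   weight
    """
    n, c_in, width = x.shape
    c_out, c_in_per_group, k = w.shape
    assert c_in == c_in_per_group * groups and c_out % groups == 0

    if flip:
        # True convolution: reverse the kernel along its spatial axis.
        w = w[:, :, ::-1]

    # Input dilation: insert (input_dilation - 1) zeros between input samples.
    dilated_width = (width - 1) * input_dilation + 1
    xd = np.zeros((n, c_in, dilated_width), dtype=x.dtype)
    xd[:, :, ::input_dilation] = x

    # Symmetric zero padding on the spatial axis.
    xp = np.pad(xd, ((0, 0), (0, 0), (padding, padding)))

    # Kernel dilation: the dilated kernel spans (K - 1) * d + 1 samples.
    span = (k - 1) * kernel_dilation + 1
    out_width = (xp.shape[2] - span) // stride + 1
    out = np.zeros((n, c_out, out_width), dtype=np.result_type(x, w))

    out_per_group = c_out // groups
    for o in range(c_out):
        g = o // out_per_group  # group this output channel belongs to
        ch = slice(g * c_in_per_group, (g + 1) * c_in_per_group)
        for i in range(out_width):
            start = i * stride
            # Every kernel_dilation-th sample inside the window: shape (N, C_in/groups, K).
            window = xp[:, ch, start:start + span:kernel_dilation]
            out[:, o, i] = np.sum(window * w[o], axis=(1, 2))
    return out
```

With x = [1, 2, 3, 4, 5] and kernel [1, 0, -1], the default call computes a cross-correlation and yields [-2, -2, -2], while `flip=True` reverses the kernel and yields [2, 2, 2].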

PyTorch

API: N/A
Strategy: Custom / Partial
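One plausible reason the PyTorch mapping is only partial: `torch.nn.functional.conv1d/conv2d/conv3d` accept `stride`, `padding`, `dilation`, and `groups`, but have no input-dilation argument and always compute cross-correlation. The sketch below is a hypothetical lowering planner (`plan_torch_general_conv` is not part of any library) that maps the arguments it can and rejects the rest. It only builds a plan, so it runs without PyTorch installed.

```python
from typing import Any, Dict, Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def _as_tuple(v: IntOrSeq, ndim: int) -> Tuple[int, ...]:
    """Broadcast an int to one value per spatial dimension."""
    if isinstance(v, int):
        return (v,) * ndim
    v = tuple(v)
    if len(v) != ndim:
        raise ValueError(f"expected {ndim} values, got {len(v)}")
    return v


def plan_torch_general_conv(ndim: int, stride: IntOrSeq = 1, padding: IntOrSeq = 0,
                            kernel_dilation: IntOrSeq = 1, input_dilation: IntOrSeq = 1,
                            groups: int = 1, flip: bool = False) -> Dict[str, Any]:
    """Hypothetical plan for lowering GeneralConv onto torch.nn.functional.conv{1,2,3}d."""
    if ndim not in (1, 2, 3):
        raise NotImplementedError("torch.nn.functional only provides conv1d/conv2d/conv3d")
    if any(d != 1 for d in _as_tuple(input_dilation, ndim)):
        # F.convNd has no input (lhs) dilation parameter, so this needs a custom path.
        raise NotImplementedError("input_dilation > 1 needs a custom lowering")
    return {
        "op": f"torch.nn.functional.conv{ndim}d",
        # F.convNd computes cross-correlation; flip=True means reversing the
        # weight's spatial dims (e.g. with torch.flip) before the call.
        "flip_weight_dims": tuple(range(2, 2 + ndim)) if flip else (),
        "kwargs": {
            "stride": _as_tuple(stride, ndim),
            "padding": _as_tuple(padding, ndim),
            "dilation": _as_tuple(kernel_dilation, ndim),
            "groups": groups,
        },
    }
```

Note that `kernel_dilation` maps to PyTorch's `dilation`, while `input_dilation` has no counterpart; a full implementation might insert zeros into the input or use a transposed convolution instead.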

JAX (Core)

API: jax.lax.conv_general_dilated
Strategy: Plugin (jax_conv_general_adapter)
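The actual `jax_conv_general_adapter` plugin is not shown in this document. The following is a hypothetical sketch of the work such an adapter has to do: expand ints to per-dimension tuples, turn symmetric padding into the `(low, high)` pairs that `jax.lax.conv_general_dilated` expects, rename the dilation arguments (`input_dilation` to `lhs_dilation`, `kernel_dilation` to `rhs_dilation`, `groups` to `feature_group_count`), and emulate `flip`, which lax does not offer. It assumes lax's default layouts (channels-first input, OI-spatial weight). The names `to_lax_kwargs` and `general_conv_jax_sketch` are illustrative only.

```python
from typing import Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def _as_tuple(v: IntOrSeq, ndim: int) -> Tuple[int, ...]:
    """Broadcast an int to one value per spatial dimension."""
    return (v,) * ndim if isinstance(v, int) else tuple(v)


def to_lax_kwargs(ndim: int, stride: IntOrSeq = 1, padding: IntOrSeq = 0,
                  kernel_dilation: IntOrSeq = 1, input_dilation: IntOrSeq = 1,
                  groups: int = 1) -> dict:
    """Translate GeneralConv arguments into jax.lax.conv_general_dilated keywords."""
    pads = _as_tuple(padding, ndim)
    return {
        "window_strides": _as_tuple(stride, ndim),
        # lax expects one (low, high) pair per spatial dimension.
        "padding": tuple((p, p) for p in pads),
        "lhs_dilation": _as_tuple(input_dilation, ndim),
        "rhs_dilation": _as_tuple(kernel_dilation, ndim),
        "feature_group_count": groups,
    }


def general_conv_jax_sketch(x, w, stride=1, padding=0, kernel_dilation=1,
                            input_dilation=1, groups=1, flip=False):
    """Hypothetical adapter; relies on lax's default NC<spatial> / OI<spatial> layouts."""
    import jax  # imported lazily so to_lax_kwargs works without JAX installed
    import jax.numpy as jnp

    ndim = x.ndim - 2
    if flip:
        # conv_general_dilated has no flip flag; reverse the kernel's spatial axes.
        w = jnp.flip(w, axis=tuple(range(2, w.ndim)))
    kwargs = to_lax_kwargs(ndim, stride, padding, kernel_dilation,
                           input_dilation, groups)
    return jax.lax.conv_general_dilated(x, w, **kwargs)
```

Because Flax NNX and PaxML / Praxis route through the same lax primitive, the same translation applies to them.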

Apple MLX

API: mlx.core.conv_general
Strategy: Direct Mapping
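The argument list of `mlx.core.conv_general` lines up one-to-one with the abstract signature, including `flip`. Tensor layout is a separate concern: MLX expects channels-last input (N, ..., C_in) and weight (C_out, ..., C_in). If the source tensors are channels-first (an assumption here), they need a transpose on the way in and out. The helpers below are a hypothetical NumPy sketch of that conversion.

```python
import numpy as np


def to_channels_last(x_ncs, w_ois):
    """Channels-first -> MLX channels-last (hypothetical helper).

    x: (N, C_in, *spatial)          -> (N, *spatial, C_in)
    w: (C_out, C_in/groups, *kernel) -> (C_out, *kernel, C_in/groups)
    """
    return np.moveaxis(x_ncs, 1, -1), np.moveaxis(w_ois, 1, -1)


def from_channels_last(y_nsc):
    """MLX output (N, *spatial, C_out) -> channels-first (N, C_out, *spatial)."""
    return np.moveaxis(y_nsc, -1, 1)
```

After conversion the remaining arguments pass through unchanged, e.g. `mx.conv_general(mx.array(x), mx.array(w), stride=..., padding=..., kernel_dilation=..., input_dilation=..., groups=..., flip=...)`, with `import mlx.core as mx`.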

Flax NNX

API: jax.lax.conv_general_dilated
Strategy: Plugin (jax_conv_general_adapter)

PaxML / Praxis

API: jax.lax.conv_general_dilated
Strategy: Plugin (jax_conv_general_adapter)