GeneralConv
General convolution over a multi-channel input, parameterized by stride, padding, kernel dilation, input dilation, and channel groups.
Abstract Signature:
GeneralConv(input: Tensor, weight: Tensor, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False)
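The parameters in the signature above interact in a fixed order. The following is a minimal NumPy reference sketch of the 1-D case. It assumes PyTorch-style layouts (input `(N, C_in, W)`, weight `(C_out, C_in / groups, K)`), integer arguments, and symmetric padding. It also assumes that `input_dilation` is applied before padding, as in `jax.lax.conv_general_dilated`, and that `flip=True` means reversing the kernel spatially (true convolution rather than cross-correlation). `general_conv_1d` is a hypothetical helper, not part of any library.

```python
import numpy as np


def general_conv_1d(x, w, stride=1, padding=0, kernel_dilation=1,
                    input_dilation=1, groups=1, flip=False):
    """Hypothetical reference for 1-D GeneralConv (NCW input, OIW weight)."""
    n, c_in, width = x.shape
    c_out, c_in_per_group, k = w.shape
    assert c_in == c_in_per_group * groups and c_out % groups == 0

    if flip:
        # Assumed meaning of `flip`: reverse the kernel along its spatial axis.
        w = w[:, :, ::-1]

    # Input dilation: insert (input_dilation - 1) zeros between input samples.
    dil_width = (width - 1) * input_dilation + 1
    xd = np.zeros((n, c_in, dil_width), dtype=x.dtype)
    xd[:, :, ::input_dilation] = x

    # Symmetric zero padding, applied to the already-dilated input.
    xp = np.pad(xd, ((0, 0), (0, 0), (padding, padding)))

    # Kernel dilation spreads the K taps over an effective extent of k_eff.
    k_eff = (k - 1) * kernel_dilation + 1
    out_w = (xp.shape[2] - k_eff) // stride + 1
    out = np.zeros((n, c_out, out_w), dtype=np.result_type(x, w))

    c_out_per_group = c_out // groups
    for o in range(c_out):
        g = o // c_out_per_group  # group that this output channel reads from
        xs = xp[:, g * c_in_per_group:(g + 1) * c_in_per_group, :]
        for t in range(out_w):
            start = t * stride
            # Gather the K dilated taps under the kernel at this output position.
            window = xs[:, :, start:start + k_eff:kernel_dilation]
            out[:, o, t] = np.sum(window * w[o], axis=(1, 2))
    return out
```

For example, with input `[0, 1, 2, 3]` and kernel `[1, 2]`, the default call yields `[2, 5, 8]`, `flip=True` yields `[1, 4, 7]`, and `input_dilation=2` yields `[0, 2, 1, 4, 2, 6]`.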
PyTorch
API: (none)
Strategy: Custom / Partial
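PyTorch's `torch.nn.functional.conv1d`/`conv2d`/`conv3d` accept `stride`, `padding`, `dilation`, and `groups`, but have no input-dilation argument. This gap is one plausible reason a mapping would be custom or partial, although the entry above does not state the reason. The sketch below is a hypothetical `general_conv_torch` helper, not the actual strategy. It emulates `input_dilation` by explicit zero insertion and assumes `flip=True` means reversing the kernel's spatial axes.

```python
import torch
import torch.nn.functional as F


def _as_tuple(v, n):
    # Broadcast an int argument to one value per spatial dimension.
    return (v,) * n if isinstance(v, int) else tuple(v)


def general_conv_torch(x, w, stride=1, padding=0, kernel_dilation=1,
                       input_dilation=1, groups=1, flip=False):
    """Hypothetical GeneralConv sketch on top of torch.nn.functional."""
    nd = x.dim() - 2  # number of spatial dimensions (1, 2, or 3)
    conv = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}[nd]

    if flip:
        # Assumed meaning of `flip`: reverse the kernel over all spatial axes.
        w = torch.flip(w, dims=tuple(range(2, w.dim())))

    ld = _as_tuple(input_dilation, nd)
    if any(d != 1 for d in ld):
        # No native input dilation: scatter x into a zero tensor with gaps.
        shape = list(x.shape[:2]) + [
            (s - 1) * d + 1 for s, d in zip(x.shape[2:], ld)
        ]
        xd = x.new_zeros(shape)
        xd[(slice(None), slice(None)) + tuple(slice(None, None, d) for d in ld)] = x
        x = xd

    return conv(x, w, stride=stride, padding=padding,
                dilation=kernel_dilation, groups=groups)
```

Materializing the dilated input costs extra memory proportional to the dilation factor. A production strategy might instead route some cases through `conv_transpose*`, but that is a design choice this entry does not document.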
JAX (Core)
API: jax.lax.conv_general_dilated
Strategy: Plugin (jax_conv_general_adapter)
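The abstract parameters line up closely with `jax.lax.conv_general_dilated`: `stride` to `window_strides`, `input_dilation` to `lhs_dilation`, `kernel_dilation` to `rhs_dilation`, and `groups` to `feature_group_count`. The sketch below is a hypothetical `general_conv_jax`, not the `jax_conv_general_adapter` plugin itself. It assumes NC<spatial> input and OI<spatial> kernel layouts (the defaults when `dimension_numbers=None`), symmetric integer padding per spatial axis, and that `flip=True` means reversing the kernel spatially.

```python
import jax.numpy as jnp
from jax import lax


def general_conv_jax(x, w, stride=1, padding=0, kernel_dilation=1,
                     input_dilation=1, groups=1, flip=False):
    """Hypothetical mapping of GeneralConv onto lax.conv_general_dilated."""
    nd = x.ndim - 2  # number of spatial dimensions

    def tup(v):
        # Broadcast an int argument to one value per spatial dimension.
        return (v,) * nd if isinstance(v, int) else tuple(v)

    if flip:
        # Assumed meaning of `flip`: true convolution via kernel reversal.
        w = jnp.flip(w, axis=tuple(range(2, w.ndim)))

    return lax.conv_general_dilated(
        x, w,
        window_strides=tup(stride),
        padding=[(p, p) for p in tup(padding)],  # (low, high) per axis
        lhs_dilation=tup(input_dilation),
        rhs_dilation=tup(kernel_dilation),
        dimension_numbers=None,  # default NC<spatial> / OI<spatial> layouts
        feature_group_count=groups,
    )
```

Because `lax.conv_general_dilated` applies `lhs_dilation` natively, no zero insertion is needed here, which is consistent with JAX being listed as a plugin mapping rather than a custom one.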
Flax NNX
API: jax.lax.conv_general_dilated
Strategy: Plugin (jax_conv_general_adapter)
PaxML / Praxis
API: jax.lax.conv_general_dilated
Strategy: Plugin (jax_conv_general_adapter)