JaxLayer

Keras layer that wraps a JAX model, supplied as a function (call_fn), so that the model can be used like any other Keras layer.

Abstract Signature:

JaxLayer(call_fn: callable)
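The call_fn in this signature is a plain JAX function. The sketch below assumes the convention described in the Keras 3 documentation: arguments are matched by their exact names (params, optional state, optional rng, inputs, optional training), and a model with non-trainable state returns a tuple of (outputs, new_state). The model itself (a tanh projection with a step counter) is hypothetical and only illustrates that contract. It runs with JAX alone, without Keras.

```python
import jax
import jax.numpy as jnp


# Hypothetical stateful model written as a pure JAX function. The argument
# names follow the naming convention JaxLayer is documented to bind by name.
def call_fn(params, state, inputs, training):
    outputs = jnp.tanh(inputs @ params["w"] + params["b"])
    # Non-trainable state is never mutated in place: an updated copy is
    # returned as the second element of the result tuple.
    new_state = {"steps": state["steps"] + (1 if training else 0)}
    return outputs, new_state


key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (3, 2)), "b": jnp.zeros((2,))}
state = {"steps": jnp.array(0)}
x = jnp.ones((4, 3))

# Because call_fn is pure it can be jit-compiled; `training` is static since
# it is a Python bool that selects behavior at trace time.
jitted = jax.jit(call_fn, static_argnames="training")
y, state = jitted(params, state, x, training=True)   # counts one step
y, state = jitted(params, state, x, training=False)  # leaves the count alone
print(y.shape, int(state["steps"]))  # (4, 2) 1
```

Keeping the function pure (state in, new state out) is what allows a wrapper to trace and compile it without side effects.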

Keras

API: keras.layers.JaxLayer
Strategy: Direct Mapping

TensorFlow

API: keras.layers.JaxLayer
Strategy: Direct Mapping