ivy.graph_transpile()
Ivy’s Graph Transpiler converts a function written in any supported framework into your framework of choice while preserving all of its logic. Because the output of transpilation is native code in the target framework, you can use it as if it were originally developed in that framework, including applying framework-specific optimizations and tools.
This makes ML projects written in other frameworks available to you, regardless of the framework you use to research, develop, or deploy systems. For example, you can:
- Use functions and building blocks published in other frameworks, such as neural networks, layers, activations, and training pipelines. For example, use Haiku modules in PyTorch to access the latest models.
- Integrate code developed in other frameworks into your own. For example, use the Kornia library in JAX for extra performance.
- Take advantage of features specific to other frameworks. For example, convert JAX code to TensorFlow for deployment.
To convert the code, the transpiler traces a computational graph using the Tracer and uses Ivy’s frontends and backends to link one framework to another. After swapping each function node in the computational graph with its equivalent Ivy frontend function, the tracer removes all the frontend wrapping and replaces it with the native functions of the target framework.
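The trace-and-swap idea can be pictured with a toy model. The sketch below is not Ivy’s implementation. It assumes that a traced graph is simply a list of nodes that name source-framework functions, and it uses a hypothetical lookup table (`SOURCE_TO_TARGET`) with plain-Python stand-ins (`target_add`, `target_sum`) in place of Ivy’s frontends, backends, and real framework functions.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Node:
    fn_name: str       # source-framework function recorded by the tracer
    inputs: List[str]  # names of the values this node reads
    output: str        # name of the value this node writes


# Plain-Python stand-ins for native target-framework functions (hypothetical).
def target_add(a, b):
    return [x + y for x, y in zip(a, b)]


def target_sum(a):
    return sum(a)


# Hypothetical lookup table: source function name -> native target function.
# In Ivy, this link is made through the frontends and backends instead.
SOURCE_TO_TARGET: Dict[str, Callable] = {
    "jax.numpy.add": target_add,
    "jax.numpy.sum": target_sum,
}


def swap_nodes(graph: List[Node]) -> List[Callable]:
    """Replace every traced source node with its native target function."""
    return [SOURCE_TO_TARGET[node.fn_name] for node in graph]


def run(graph: List[Node], fns: List[Callable], env: Dict[str, object]) -> Dict[str, object]:
    """Execute the swapped graph in order, reading and writing named values."""
    for node, fn in zip(graph, fns):
        env[node.output] = fn(*(env[name] for name in node.inputs))
    return env


# A graph as a tracer might record it for: jax.numpy.sum(jax.numpy.add(x, y))
traced = [
    Node("jax.numpy.add", ["x", "y"], "z"),
    Node("jax.numpy.sum", ["z"], "out"),
]
result = run(traced, swap_nodes(traced), {"x": [1.0, 2.0], "y": [3.0, 4.0]})
print(result["out"])  # 10.0
```

Each node only records which function ran and how values flow between nodes, so swapping the functions changes the framework without changing the program’s logic.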
Graph Transpiler API
- ivy.graph_transpile(*objs, source=None, to=None, debug_mode=False, args=None, kwargs=None, params_v=None)
Transpiles a Callable or set of them from a source framework to another framework. If args or kwargs are specified, transpilation is performed eagerly; otherwise, it happens lazily.
- Parameters:
  - objs (Callable) – Native callable(s) to transpile.
  - source (Optional[str]) – The framework that objs are from. This must be provided unless objs is a framework-specific module.
  - to (Optional[str]) – The target framework to transpile objs to.
  - debug_mode (bool) – Whether to transpile to Ivy first, before the final compilation to the target framework. If the target is ivy, this flag makes no difference.
  - args (Optional[Tuple]) – If specified, positional arguments used to transpile eagerly.
  - kwargs (Optional[dict]) – If specified, keyword arguments used to transpile eagerly.
  - params_v – Parameters of a Haiku model. When transpiling Haiku models, the parameters must be passed explicitly to the function call.
- Return type:
Union[Graph, LazyGraph, ModuleType, ivy.Module, torch.nn.Module, tf.keras.Model, hk.Module]
- Returns:
A transpiled Graph or a non-initialized LazyGraph. If the object is a native trainable module, the corresponding module in the target framework is returned. If the object is a ModuleType, the function returns a copy of the module with every method lazily transpiled.
Using the transpiler
Like the ivy.trace function, ivy.graph_transpile can be used eagerly or lazily. If you pass the necessary arguments, transpilation happens immediately; otherwise, it happens the first time you call the transpiled function with the proper arguments.
In both cases, the arguments and keyword arguments can be arrays from either the source framework or the target (to) framework.
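The eager/lazy rule can be sketched in plain Python. This is a toy under stated assumptions, not Ivy’s code: `ToyGraph` and `ToyLazyGraph` are hypothetical stand-ins for Graph and LazyGraph, and `record_and_wrap` plays the role of the actual transpilation step. Supplying args or kwargs triggers the work immediately. Otherwise the work runs on the first call, and the result is cached.

```python
from typing import Callable, List, Optional, Tuple


class ToyGraph:
    """Hypothetical stand-in for a transpiled Graph: wraps a 'target' callable."""

    def __init__(self, fn: Callable):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class ToyLazyGraph:
    """Hypothetical stand-in for a LazyGraph: transpiles on the first call only."""

    def __init__(self, fn: Callable, transpile_fn: Callable):
        self._fn = fn
        self._transpile_fn = transpile_fn
        self._graph: Optional[ToyGraph] = None

    def __call__(self, *args, **kwargs):
        if self._graph is None:
            # First call: real arguments are now available, so transpile now.
            self._graph = self._transpile_fn(self._fn, args, kwargs)
        return self._graph(*args, **kwargs)


def toy_transpile(fn: Callable, transpile_fn: Callable,
                  args: Optional[Tuple] = None, kwargs: Optional[dict] = None):
    """Transpile eagerly if args or kwargs are given, lazily otherwise."""
    if args is not None or kwargs is not None:
        return transpile_fn(fn, args or (), kwargs or {})
    return ToyLazyGraph(fn, transpile_fn)


log: List[int] = []


def record_and_wrap(fn: Callable, args: tuple, kwargs: dict) -> ToyGraph:
    # Pretend to transpile: log how many example arguments were seen.
    log.append(len(args))
    return ToyGraph(fn)


def source_fn(x):
    return sum(x)


eager_graph = toy_transpile(source_fn, record_and_wrap, args=([1.0, 2.0],))
print(log)                     # [1]: transpiled immediately
lazy_graph = toy_transpile(source_fn, record_and_wrap)
print(log)                     # [1]: nothing transpiled yet
print(lazy_graph([1.0, 2.0]))  # 3.0: transpiles on this first call
print(lazy_graph([3.0, 4.0]))  # 7.0: reuses the cached graph
print(log)                     # [1, 1]
```

Caching the result of the first call lets later calls run the transpiled graph directly. This matches the behaviour of lazy_graph in the function-transpilation example.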
Transpiling functions
Let’s start by transpiling a simple function. In the snippet below, we transpile a small JAX function to PyTorch, both eagerly and lazily.
```python
import ivy
import jax

ivy.set_backend("jax")

# Simple JAX function to transpile
def test_fn(x):
    return jax.numpy.sum(x)

x1 = ivy.array([1., 2.])

# Arguments are available -> transpilation happens eagerly
eager_graph = ivy.graph_transpile(test_fn, source="jax", to="torch", args=(x1,))

# eager_graph is now torch code and runs efficiently
ret = eager_graph(x1)

# Arguments are not available -> transpilation happens lazily
lazy_graph = ivy.graph_transpile(test_fn, source="jax", to="torch")

# The first call provides the arguments, so transpilation happens here
ret = lazy_graph(x1)

# lazy_graph is now torch code and runs efficiently
ret = lazy_graph(x1)
```
Transpiling Libraries
Likewise, you can use ivy.graph_transpile to convert entire libraries and modules with just one line of code!
```python
import ivy
import kornia
import requests
import jax.numpy as jnp
from PIL import Image

# transpile kornia from torch to jax
jax_kornia = ivy.graph_transpile(kornia, source="torch", to="jax")

# get an image
url = "http://images.cocodataset.org/train2017/000000000034.jpg"
raw_img = Image.open(requests.get(url, stream=True).raw)

# convert it to the format expected by kornia
img = jnp.transpose(jnp.array(raw_img), (2, 0, 1))
img = jnp.expand_dims(img, 0) / 255

# and use the transpiled version of any function from the library!
out = jax_kornia.enhance.sharpness(img, 5)
```
Transpiling Modules
Last but not least, Ivy can also transpile trainable modules from one framework to another. At the moment, we support torch.nn.Module when to="torch", tf.keras.Model when to="tensorflow", and Haiku models when to="jax".
```python
import ivy
import timm
import torch
import jax
import haiku as hk

# Get a pretrained PyTorch model
mlp_encoder = timm.create_model("mixer_b16_224", pretrained=True, num_classes=0)

# Transpile it into a hk.Module with the corresponding parameters
noise = torch.randn(1, 3, 224, 224)
mlp_encoder = ivy.graph_transpile(mlp_encoder, to="jax", args=(noise,))

# Build a classifier using the transpiled encoder
class Classifier(hk.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.encoder = mlp_encoder()
        self.fc = hk.Linear(output_size=num_classes, with_bias=True)

    def __call__(self, x):
        x = self.encoder(x)
        x = self.fc(x)
        return x

def _forward_classifier(x):
    module = Classifier()
    return module(x)

# Transform the classifier and use it as a standard hk.Module
rng_key = jax.random.PRNGKey(42)
x = jax.random.uniform(key=rng_key, shape=(1, 3, 224, 224), dtype=jax.numpy.float32)
forward_classifier = hk.transform(_forward_classifier)
params = forward_classifier.init(rng=rng_key, x=x)
ret = forward_classifier.apply(params, None, x)
```
Examples
Here, we transpile a Hugging Face model from PyTorch to TensorFlow and then use the resulting model with TensorFlow tensors directly:
```python
import ivy
from transformers import AutoImageProcessor, ResNetForImageClassification
from datasets import load_dataset

# Set backend to torch
ivy.set_backend("torch")

# Download the input image
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# Setting the model
image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

# Preprocess the image into PyTorch tensors to use as example inputs
inputs = image_processor(image, return_tensors="pt")

# Transpiling the model to tensorflow
tf_model = ivy.graph_transpile(model, source="torch", to="tensorflow", kwargs=inputs)

# Using the transpiled model with TensorFlow inputs
tf_inputs = image_processor(image, return_tensors="tf")
ret = tf_model(None, **tf_inputs)
```