Transpiling Models from PyTorch to TensorFlow
You can install the dependencies required for this notebook by running the cell below ⬇️, or check out the Get Started section of the docs to find out more about installing ivy.
```
!pip install ivy
!pip install torch
!pip install tensorflow
```
Here we’ll go through an example of how a model written in PyTorch can be converted to, and used in, TensorFlow via ivy.transpile. First, let’s import a simple torch model.
```python
from example_models import SimpleModel

"""
This model is defined as follows:

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 3, kernel_size=3)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(3 * 26 * 26, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
"""
```
Next, we convert the model to TensorFlow. ivy.transpile takes the PyTorch class and returns a TensorFlow counterpart of it:
```python
import ivy

TFSimpleModel = ivy.transpile(SimpleModel, source="torch", target="tensorflow")
```
Now we can instantiate the transpiled model and call it on a TensorFlow tensor:
```python
import tensorflow as tf

tf_model = TFSimpleModel()
tf_model(tf.random.normal((1, 1, 28, 28))).shape
```

```
TensorShape([1, 10])
```
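The input above uses a channels-first layout, (batch, channels, height, width). This matches the PyTorch conventions of the original model, where conv1 expects 1 channel on axis 1. TensorFlow data pipelines often produce channels-last (batch, height, width, channels) data instead. The following is a minimal sketch of reordering such data before calling the model. It assumes NumPy and a hypothetical batch of 8 grayscale 28×28 images.

```python
import numpy as np

# Hypothetical channels-last batch: (batch, height, width, channels)
nhwc_batch = np.random.default_rng(0).standard_normal((8, 28, 28, 1)).astype(np.float32)

# Move the channel axis to position 1: (batch, channels, height, width)
nchw_batch = np.transpose(nhwc_batch, (0, 3, 1, 2))
print(nchw_batch.shape)  # (8, 1, 28, 28)
```

Inside a TensorFlow graph, the equivalent reordering is tf.transpose(x, perm=[0, 3, 1, 2]).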
We can also use TensorFlow-specific features such as tf.function, which traces the model's forward pass into a TensorFlow graph:
```python
compiled_model = tf.function(tf_model)
compiled_model(tf.random.normal((1, 1, 28, 28))).shape
```

```
TensorShape([1, 10])
```
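tf.function traces the wrapped Python callable into a graph and caches that graph per input signature (shape and dtype). As a result, calling compiled_model with a new batch size triggers a fresh trace. The sketch below uses a toy function rather than the transpiled model to show when retracing happens. The traced_shapes list is purely illustrative instrumentation: Python side effects such as list.append run only while tracing, not when the cached graph executes.

```python
import tensorflow as tf

traced_shapes = []  # records the input shape each time TensorFlow traces `scale`


@tf.function
def scale(x):
    traced_shapes.append(x.shape)  # Python side effect: runs only during tracing
    return x * 2.0


scale(tf.ones((1, 10)))  # first call: traces a graph for shape (1, 10)
scale(tf.ones((1, 10)))  # same shape and dtype: reuses the cached graph
scale(tf.ones((4, 10)))  # new shape: triggers a second trace
print(len(traced_shapes))  # 2
```

tf.function also accepts an input_signature argument, which allows a single trace to cover every batch size. For example: tf.function(tf_model, input_signature=[tf.TensorSpec(shape=[None, 1, 28, 28], dtype=tf.float32)]). Whether the transpiled model traces cleanly with an unknown batch dimension is an assumption worth checking.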