Source code for opennmt.utils.exporters

"""Define model exporters."""

import abc
import os
import tempfile

import ctranslate2
import tensorflow as tf

from opennmt.utils import misc


class Exporter(abc.ABC):
    """Base class for model exporters.

    Concrete exporters implement :meth:`_export_model`. The public
    :meth:`export` entry point calls it and then copies any extra assets
    reported by the model next to the exported files.
    """

    def export(self, model, export_dir):
        """Exports :obj:`model` to :obj:`export_dir`.

        The model is exported first. Afterwards, the files returned by
        ``model.export_assets`` (if any) are copied into the ``assets.extra``
        subdirectory of :obj:`export_dir`.

        Args:
          model: The model to export. It must provide an ``export_assets``
            method taking a directory and returning a mapping from file names
            to file paths.
          export_dir: The directory receiving the exported model.

        Raises:
          ValueError: if :obj:`model` is not supported by this exporter.
        """
        self._export_model(model, export_dir)

        # The model writes its assets in a scratch directory that is removed
        # as soon as the files have been copied to their final location.
        with tempfile.TemporaryDirectory() as staging_dir:
            assets = model.export_assets(staging_dir)
            if not assets:
                return

            target_dir = os.path.join(export_dir, "assets.extra")
            tf.io.gfile.makedirs(target_dir)
            for asset_name, asset_path in assets.items():
                target_path = os.path.join(target_dir, asset_name)
                tf.io.gfile.copy(asset_path, target_path, overwrite=True)
            tf.get_logger().info("Extra assets written to: %s", target_dir)

    @abc.abstractmethod
    def _export_model(self, model, export_dir):
        """Writes the exported form of :obj:`model` into :obj:`export_dir`.

        Must be implemented by each concrete exporter.
        """
        raise NotImplementedError()
# Registry of the available exporter classes, indexed by their registered name.
_EXPORTERS_REGISTRY = misc.ClassRegistry(base_class=Exporter)

# Class decorator used below to register each exporter under a name.
register_exporter = _EXPORTERS_REGISTRY.register


def make_exporter(name, **kwargs):
    """Creates a new exporter.

    Args:
      name: The exporter name.
      **kwargs: Additional arguments to pass to the exporter constructor.

    Returns:
      A :class:`opennmt.utils.Exporter` instance.

    Raises:
      ValueError: if :obj:`name` is invalid.
    """
    exporter_cls = _EXPORTERS_REGISTRY.get(name)
    if exporter_cls is not None:
        return exporter_cls(**kwargs)
    # The registry lookup returned None: the name is not registered.
    raise ValueError("Invalid exporter name: %s" % name)


def list_exporters():
    """Lists the name of registered exporters."""
    registered_names = _EXPORTERS_REGISTRY.class_names
    return registered_names
@register_exporter(name="saved_model")
class SavedModelExporter(Exporter):
    """SavedModel exporter."""

    def _export_model(self, model, export_dir):
        """Saves :obj:`model` as a SavedModel in :obj:`export_dir`.

        The function returned by ``model.serve_function()`` is used as the
        signatures of the SavedModel.
        """
        serving_signatures = model.serve_function()
        tf.saved_model.save(model, export_dir, signatures=serving_signatures)
@register_exporter(name="checkpoint")
class CheckpointExporter(Exporter):
    """Checkpoint exporter."""

    def _export_model(self, model, export_dir):
        """Writes a checkpoint of :obj:`model` with the prefix ``ckpt``.

        The checkpoint tracks the model under the ``model`` attribute and is
        written with the path prefix ``<export_dir>/ckpt``.
        """
        ckpt_prefix = os.path.join(export_dir, "ckpt")
        tf.train.Checkpoint(model=model).write(ckpt_prefix)
@register_exporter(name="tflite")
class TFLiteExporter(Exporter):
    """TensorFlow Lite exporter."""

    def __init__(self, quantization=None):
        """Initializes the exporter.

        Args:
          quantization: The post-training quantization to apply when
            converting the model. Can be "float16" or "dynamic_range".
            Default is no quantization.

        Raises:
          ValueError: if :obj:`quantization` is invalid.
        """
        accepted_quantization = (
            "float16",
            "dynamic_range",
        )
        if quantization is not None and quantization not in accepted_quantization:
            raise ValueError(
                "Unsupported quantization '%s' for TensorFlow Lite, accepted values are: %s"
                % (quantization, ", ".join(accepted_quantization))
            )
        self._quantization = quantization

    def _export_model(self, model, export_dir):
        """Converts :obj:`model` to TensorFlow Lite and writes the result.

        The converted model is written to ``<export_dir>/opennmt.tflite``.
        The export directory is created when it does not exist yet.
        """
        # Get the concrete function of the model that the converter will
        # translate to TensorFlow Lite.
        tflite_concrete_fn = model.tflite_function().get_concrete_function()

        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [tflite_concrete_fn]
        )
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
            tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
        ]

        if self._quantization == "dynamic_range":
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        elif self._quantization == "float16":
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_types = [tf.float16]

        # Convert before touching the filesystem so that a failed conversion
        # does not leave an empty export directory behind.
        tflite_model = converter.convert()

        # Opening the output file for writing does not create missing parent
        # directories, so make sure the export directory exists first.
        tf.io.gfile.makedirs(export_dir)
        tflite_model_path = os.path.join(export_dir, "opennmt.tflite")
        with tf.io.gfile.GFile(tflite_model_path, "wb") as f:
            f.write(tflite_model)
@register_exporter(name="tflite_float16")
class TFLiteFloat16Exporter(TFLiteExporter):
    """TensorFlow Lite exporter with float16 quantization."""

    def __init__(self):
        """Initializes the parent exporter with the "float16" quantization."""
        super().__init__(quantization="float16")
@register_exporter(name="tflite_dynamic_range")
class TFLiteDynamicRangeExporter(TFLiteExporter):
    """TensorFlow Lite exporter with dynamic range quantization."""

    def __init__(self):
        """Initializes the parent exporter with the "dynamic_range" quantization."""
        super().__init__(quantization="dynamic_range")
@register_exporter(name="ctranslate2")
class CTranslate2Exporter(Exporter):
    """CTranslate2 exporter."""

    def __init__(self, quantization=None):
        """Initializes the exporter.

        Args:
          quantization: Quantize model weights to this type when exporting the
            model. Can be "int8", "int16", "float16", or "int8_float16".
            Default is no quantization.

        Raises:
          ValueError: if :obj:`quantization` is invalid.

        Note:
          The ``ctranslate2`` package is imported at the top of this module,
          so a missing package is reported when the module is imported rather
          than by this constructor.
        """
        supported_types = ("int8", "int16", "float16", "int8_float16")
        is_supported = quantization is None or quantization in supported_types
        if not is_supported:
            raise ValueError(
                "Invalid quantization '%s' for CTranslate2, accepted values are: %s"
                % (quantization, ", ".join(supported_types))
            )
        self._quantization = quantization

    def _export_model(self, model, export_dir):
        """Converts :obj:`model` to a CTranslate2 model in :obj:`export_dir`.

        The conversion is run with ``force=True`` and with the quantization
        selected in the constructor.
        """
        # The model variables are created first when the model is not built.
        if not model.built:
            model.create_variables()

        ct2_converter = ctranslate2.converters.OpenNMTTFConverter(model)
        ct2_converter.convert(
            export_dir,
            quantization=self._quantization,
            force=True,
        )
@register_exporter(name="ctranslate2_int8")
class CTranslate2Int8Exporter(CTranslate2Exporter):
    """CTranslate2 exporter with int8 quantization."""

    def __init__(self):
        """Initializes the parent exporter with the "int8" quantization."""
        super().__init__(quantization="int8")
@register_exporter(name="ctranslate2_int16")
class CTranslate2Int16Exporter(CTranslate2Exporter):
    """CTranslate2 exporter with int16 quantization."""

    def __init__(self):
        """Initializes the parent exporter with the "int16" quantization."""
        super().__init__(quantization="int16")
@register_exporter(name="ctranslate2_float16")
class CTranslate2Float16Exporter(CTranslate2Exporter):
    """CTranslate2 exporter with float16 quantization."""

    def __init__(self):
        """Initializes the parent exporter with the "float16" quantization."""
        super().__init__(quantization="float16")
@register_exporter(name="ctranslate2_int8_float16")
class CTranslate2Int8Float16Exporter(CTranslate2Exporter):
    """CTranslate2 exporter with int8_float16 quantization."""

    def __init__(self):
        """Initializes the parent exporter with the "int8_float16" quantization."""
        super().__init__(quantization="int8_float16")