Source code for opennmt.inputters.inputter

"""Define generic inputters."""

import abc

import tensorflow as tf

from opennmt.data import dataset as dataset_util
from opennmt.layers import common
from opennmt.layers.reducer import ConcatReducer, JoinReducer
from opennmt.utils import misc



class Inputter(tf.keras.layers.Layer):
    """Base class for inputters."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._asset_prefix = ""

    @property
    def asset_prefix(self):
        r"""The asset prefix is used to differentiate resources of parallel inputters.
        The most basic examples are the "source\_" and "target\_" prefixes.

        - When reading the data configuration, the inputter will read fields that
          start with this prefix (e.g. "source_vocabulary").
        - Assets exported by this inputter start with this prefix.
        """
        return self._asset_prefix

    @asset_prefix.setter
    def asset_prefix(self, asset_prefix):
        """Sets the asset prefix for this inputter."""
        self._asset_prefix = asset_prefix

    @property
    def num_outputs(self):
        """The number of parallel outputs produced by this inputter."""
        return 1
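
    # Illustrative sketch (not part of the original module): with the "source_"
    # asset prefix, an inputter reads prefixed fields from the user data
    # configuration, e.g. a dictionary like:
    #
    #   data_config = {
    #       "source_vocabulary": "src-vocab.txt",
    #       "target_vocabulary": "tgt-vocab.txt",
    #   }
    #
    # The file paths above are hypothetical examples.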

    def initialize(self, data_config):
        """Initializes the inputter.

        Args:
          data_config: A dictionary containing the data configuration set
            by the user.
        """
        _ = data_config
        return

    def export_assets(self, asset_dir):
        """Exports assets used by this inputter.

        Args:
          asset_dir: The directory where assets can be written.

        Returns:
          A dictionary containing additional assets used by the inputter.
        """
        _ = asset_dir
        return {}

    @abc.abstractmethod
    def make_dataset(self, data_file, training=None):
        """Creates the base dataset required by this inputter.

        Args:
          data_file: The data file.
          training: Run in training mode.

        Returns:
          A ``tf.data.Dataset`` instance or a list of ``tf.data.Dataset`` instances.
        """
        raise NotImplementedError()

    def get_dataset_size(self, data_file):
        """Returns the dataset size.

        If the inputter can efficiently compute the dataset size from a training
        file on disk, it can optionally override this method. Otherwise, we may
        compute the size later with a generic and slower approach (iterating over
        the dataset instance).

        Args:
          data_file: The data file.

        Returns:
          The dataset size or ``None``.
        """
        _ = data_file
        return None

    def make_inference_dataset(
        self,
        features_file,
        batch_size,
        batch_type="examples",
        length_bucket_width=None,
        num_threads=1,
        prefetch_buffer_size=None,
    ):
        """Builds a dataset to be used for inference.

        For evaluation and training datasets, see
        :class:`opennmt.inputters.ExampleInputter`.

        Args:
          features_file: The test file.
          batch_size: The batch size to use.
          batch_type: The batching strategy to use: can be "examples" or "tokens".
          length_bucket_width: The width of the length buckets to select batch
            candidates from (for efficiency). Set ``None`` to not constrain batch
            formation.
          num_threads: The number of elements processed in parallel.
          prefetch_buffer_size: The number of batches to prefetch asynchronously.
            If ``None``, use an automatically tuned value.

        Returns:
          A ``tf.data.Dataset``.

        See Also:
          :func:`opennmt.data.inference_pipeline`
        """
        transform_fns = _get_dataset_transforms(
            self, num_threads=num_threads, training=False
        )
        dataset = self.make_dataset(features_file, training=False)
        dataset = dataset.apply(
            dataset_util.inference_pipeline(
                batch_size,
                batch_type=batch_type,
                transform_fns=transform_fns,
                length_bucket_width=length_bucket_width,
                length_fn=self.get_length,
                num_threads=num_threads,
                prefetch_buffer_size=prefetch_buffer_size,
            )
        )
        return dataset
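
    # Illustrative usage sketch (assumes ``inputter`` is a concrete Inputter
    # subclass that was already initialized with its data configuration; the
    # file name is hypothetical):
    #
    #   dataset = inputter.make_inference_dataset("test.txt", batch_size=32)
    #   for features in dataset.take(1):
    #       ...  # features is a dict of batched ``tf.Tensor``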

    @abc.abstractmethod
    def input_signature(self):
        """Returns the input signature of this inputter."""
        raise NotImplementedError()

    def get_length(self, features, ignore_special_tokens=False):
        """Returns the length of the input features, if defined.

        Args:
          features: The dictionary of input features.
          ignore_special_tokens: Ignore special tokens that were added by the
            inputter (e.g. <s> and/or </s>).

        Returns:
          The length.
        """
        _ = ignore_special_tokens
        return features.get("length")

    def get_padded_shapes(self, element_spec, maximum_length=None):
        """Returns the padded shapes for dataset elements.

        For example, this is used during batch size autotuning to pad all batches
        to the maximum sequence length.

        Args:
          element_spec: A nested structure of ``tf.TensorSpec``.
          maximum_length: Pad batches to this maximum length.

        Returns:
          A nested structure of ``tf.TensorShape``.
        """
        return tf.nest.map_structure(
            lambda spec: spec.shape
            if spec.shape.rank == 0
            else tf.TensorShape([maximum_length]).concatenate(spec.shape[1:]),
            element_spec,
        )

    def has_prepare_step(self):
        """Returns ``True`` if this inputter implements a data preparation step
        in method :meth:`opennmt.inputters.Inputter.prepare_elements`.
        """
        return False

    def prepare_elements(self, elements, training=None):
        """Prepares dataset elements.

        This method is called on a batch of dataset elements. For example, it
        can be overridden to apply an external pre-tokenization.

        Note that the results of the method are unbatched and then passed to
        method :meth:`opennmt.inputters.Inputter.make_features`.

        Args:
          elements: A batch of dataset elements.
          training: Run in training mode.

        Returns:
          A (possibly nested) structure of ``tf.Tensor``.
        """
        return elements

    @abc.abstractmethod
    def make_features(self, element=None, features=None, training=None):
        """Creates features from data.

        This is typically called in a data pipeline (such as ``Dataset.map``).
        Common transformations include tokenization, parsing, vocabulary lookup,
        etc.

        This method accepts both a single :obj:`element` from the dataset or a
        partially built dictionary of :obj:`features`.

        Args:
          element: An element from the dataset returned by
            :meth:`opennmt.inputters.Inputter.make_dataset`.
          features: An optional and possibly partial dictionary of features to
            augment.
          training: Run in training mode.

        Returns:
          A dictionary of ``tf.Tensor``.
        """
        raise NotImplementedError()

    def keep_for_training(self, features, maximum_length=None):
        """Returns ``True`` if this example should be kept for training.

        Args:
          features: A dictionary of ``tf.Tensor``.
          maximum_length: The maximum length used for training.

        Returns:
          A boolean.
        """
        length = self.get_length(features)
        if length is None:
            return True
        is_valid = tf.greater(length, 0)
        if maximum_length is not None:
            is_valid = tf.logical_and(is_valid, tf.less_equal(length, maximum_length))
        return is_valid

    def call(self, features, training=None):
        """Creates the model input from the features (e.g. word embeddings).

        Args:
          features: A dictionary of ``tf.Tensor``, the output of
            :meth:`opennmt.inputters.Inputter.make_features`.
          training: Run in training mode.

        Returns:
          The model input.
        """
        _ = training
        return features

    def visualize(self, model_root, log_dir):
        """Visualizes the transformation, usually embeddings.

        Args:
          model_root: The root model object.
          log_dir: The active log directory.
        """
        _ = model_root
        _ = log_dir
        return
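
# Minimal subclassing sketch (illustrative only, not part of the original
# module): a concrete inputter must at least implement make_dataset(),
# input_signature() and make_features(). The class below reads one line of
# text per example and returns it unprocessed; all names are hypothetical.
#
#   class RawTextInputter(Inputter):
#       def make_dataset(self, data_file, training=None):
#           return tf.data.TextLineDataset(data_file)
#
#       def input_signature(self):
#           return {"text": tf.TensorSpec([None], tf.string)}
#
#       def make_features(self, element=None, features=None, training=None):
#           if features is None:
#               features = {}
#           features["text"] = element
#           return features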

class MultiInputter(Inputter):
    """An inputter that gathers multiple inputters, possibly nested."""

    def __init__(self, inputters, reducer=None):
        if not isinstance(inputters, list) or not inputters:
            raise ValueError("inputters must be a non empty list")
        dtype = inputters[0].dtype
        for inputter in inputters:
            if inputter.dtype != dtype:
                raise TypeError("All inputters must have the same dtype")
        super().__init__(dtype=dtype)
        self.inputters = inputters
        self.reducer = reducer
        self.asset_prefix = ""  # Generate the default prefix for sub-inputters.

    @Inputter.asset_prefix.setter
    def asset_prefix(self, asset_prefix):
        self._asset_prefix = asset_prefix
        for i, inputter in enumerate(self.inputters):
            inputter.asset_prefix = "%s%d_" % (asset_prefix, i + 1)

    @property
    def num_outputs(self):
        if self.reducer is None or isinstance(self.reducer, JoinReducer):
            return len(self.inputters)
        return 1

    def get_leaf_inputters(self):
        """Returns a list of all leaf Inputter instances."""
        inputters = []
        for inputter in self.inputters:
            if isinstance(inputter, MultiInputter):
                inputters.extend(inputter.get_leaf_inputters())
            else:
                inputters.append(inputter)
        return inputters

    def __getattribute__(self, name):
        if name == "built":
            return all(inputter.built for inputter in self.inputters)
        else:
            return super().__getattribute__(name)

    def initialize(self, data_config):
        for inputter in self.inputters:
            inputter.initialize(
                misc.RelativeConfig(
                    data_config, inputter.asset_prefix, config_name="data"
                )
            )

    def export_assets(self, asset_dir):
        assets = {}
        for inputter in self.inputters:
            assets.update(inputter.export_assets(asset_dir))
        return assets

    def has_prepare_step(self):
        return any(inputter.has_prepare_step() for inputter in self.inputters)

    def prepare_elements(self, elements, training=None):
        return tuple(
            inputter.prepare_elements(elts)
            for inputter, elts in zip(self.inputters, elements)
        )

    def visualize(self, model_root, log_dir):
        for inputter in self.inputters:
            inputter.visualize(model_root, log_dir)

class ParallelInputter(MultiInputter):
    """A multi inputter that processes parallel data."""

    def __init__(
        self, inputters, reducer=None, share_parameters=False, combine_features=True
    ):
        """Initializes a parallel inputter.

        Args:
          inputters: A list of :class:`opennmt.inputters.Inputter`.
          reducer: A :class:`opennmt.layers.Reducer` to merge all inputs. If
            set, parallel inputs are assumed to have the same length.
          share_parameters: Share the inputters parameters.
          combine_features: Combine each inputter features in a single dict or
            return them separately. This is typically ``True`` for multi source
            inputs but ``False`` for features/labels parallel data.
        """
        super().__init__(inputters, reducer=reducer)
        self.combine_features = combine_features
        self.share_parameters = share_parameters

    def _structure(self):
        """Returns the nested structure that represents this parallel inputter."""
        return [
            inputter._structure() if isinstance(inputter, ParallelInputter) else None
            for inputter in self.inputters
        ]

    def make_dataset(self, data_file, training=None):
        if not isinstance(data_file, list):
            data_file = [data_file]

        # For evaluation and inference, accept a flat list of data files for nested inputters.
        # This is needed when nesting can't easily be represented (e.g. on the command line).
        if not training:
            try:
                data_file = tf.nest.pack_sequence_as(
                    self._structure(), tf.nest.flatten(data_file)
                )
            except ValueError:
                data_file = []  # This will raise the error below.

        if len(data_file) != len(self.inputters):
            raise ValueError(
                "The number of data files must be the same as the number of inputters"
            )

        num_files = -1
        datasets = []
        for i, (inputter, data) in enumerate(zip(self.inputters, data_file)):
            dataset = inputter.make_dataset(data, training=training)
            if not isinstance(dataset, list):
                dataset = [dataset]
            datasets.append(dataset)
            if num_files < 0:
                num_files = len(dataset)
            elif len(dataset) != num_files:
                raise ValueError(
                    "All parallel inputs must have the same number of data files, "
                    "saw %d files for input 0 but got %d files for input %d"
                    % (num_files, len(dataset), i)
                )

        parallel_datasets = [
            tf.data.Dataset.zip(tuple(parallel_dataset))
            for parallel_dataset in zip(*datasets)
        ]
        if len(parallel_datasets) == 1:
            return parallel_datasets[0]
        if not training:
            raise ValueError("Only training data can be configured to multiple files")
        return parallel_datasets

    def get_dataset_size(self, data_file):
        common_size = None
        for inputter, data in zip(self.inputters, data_file):
            size = inputter.get_dataset_size(data)
            if size is not None:
                if common_size is None:
                    common_size = size
                elif size != common_size:
                    raise RuntimeError("Parallel datasets do not have the same size")
        return common_size

    def input_signature(self):
        if self.combine_features:
            signature = {}
            for i, inputter in enumerate(self.inputters):
                for key, value in inputter.input_signature().items():
                    signature["{}_{}".format(key, i)] = value
            return signature
        else:
            return tuple(inputter.input_signature() for inputter in self.inputters)

    def _index_features(self, features, index):
        if self.combine_features:
            return misc.extract_prefixed_keys(features, "inputter_{}_".format(index))
        else:
            return features[index]

    def get_length(self, features, ignore_special_tokens=False):
        lengths = [
            inputter.get_length(
                self._index_features(features, i),
                ignore_special_tokens=ignore_special_tokens,
            )
            for i, inputter in enumerate(self.inputters)
        ]
        if self.reducer is None:
            return lengths
        else:
            return lengths[0]

    def get_padded_shapes(self, element_spec, maximum_length=None):
        if maximum_length is None:
            maximum_length = [None for _ in self.inputters]
        elif not isinstance(maximum_length, (list, tuple)) or len(
            maximum_length
        ) != len(self.inputters):
            raise ValueError(
                "A maximum length should be set for each parallel inputter"
            )
        if self.combine_features:
            shapes = {}
            for i, (inputter, length) in enumerate(
                zip(self.inputters, maximum_length)
            ):
                prefix = "inputter_%d_" % i
                spec = misc.extract_prefixed_keys(element_spec, prefix)
                sub_shapes = inputter.get_padded_shapes(spec, maximum_length=length)
                for key, value in sub_shapes.items():
                    shapes["%s%s" % (prefix, key)] = value
            return shapes
        else:
            return type(element_spec)(
                inputter.get_padded_shapes(spec, maximum_length=length)
                for inputter, spec, length in zip(
                    self.inputters, element_spec, maximum_length
                )
            )

    def make_features(self, element=None, features=None, training=None):
        if self.combine_features:
            if features is None:
                features = {}
            for i, inputter in enumerate(self.inputters):
                prefix = "inputter_%d_" % i
                sub_features = misc.extract_prefixed_keys(features, prefix)
                if not sub_features:
                    # Also try to read the format produced by the serving features.
                    sub_features = misc.extract_suffixed_keys(features, "_%d" % i)
                sub_features = inputter.make_features(
                    element=element[i] if element is not None else None,
                    features=sub_features,
                    training=training,
                )
                for key, value in sub_features.items():
                    features["%s%s" % (prefix, key)] = value
            return features
        else:
            if features is None:
                features = [{} for _ in self.inputters]
            else:
                features = list(features)
            for i, inputter in enumerate(self.inputters):
                features[i] = inputter.make_features(
                    element=element[i] if element is not None else None,
                    features=features[i],
                    training=training,
                )
            return tuple(features)
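
    # Illustrative note (not from the original module): with combine_features=True,
    # the sub-inputter features are merged into a single dict under indexed
    # prefixes, e.g. a two-source input could produce keys such as
    # "inputter_0_ids", "inputter_0_length", "inputter_1_ids", "inputter_1_length"
    # (the exact keys depend on the sub-inputters used).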

    def keep_for_training(self, features, maximum_length=None):
        if not isinstance(maximum_length, list):
            maximum_length = [maximum_length]
        # Unset maximum lengths are set to None (i.e. no constraint).
        maximum_length += [None] * (len(self.inputters) - len(maximum_length))
        constraints = []
        for i, inputter in enumerate(self.inputters):
            keep = inputter.keep_for_training(
                self._index_features(features, i), maximum_length=maximum_length[i]
            )
            if isinstance(keep, bool):
                if not keep:
                    return False
                continue
            constraints.append(keep)
        if not constraints:
            return True
        return tf.reduce_all(constraints)

    def build(self, input_shape):
        if self.share_parameters:
            # When sharing parameters, build the first leaf inputter and then set
            # all attributes with parameters to the other inputters.
            leaves = self.get_leaf_inputters()
            first, others = leaves[0], leaves[1:]
            for inputter in others:
                if type(inputter) is not type(first):  # noqa: E721
                    raise ValueError(
                        "Each inputter must be of the same type for parameter sharing"
                    )
            first.build(input_shape)
            for name, attr in first.__dict__.copy().items():
                if isinstance(attr, tf.Variable) or (
                    isinstance(attr, tf.Module) and attr.variables
                ):
                    for inputter in others:
                        setattr(inputter, name, attr)
                        inputter.built = True
        else:
            for inputter in self.inputters:
                inputter.build(input_shape)
        super().build(input_shape)

    def call(self, features, training=None):
        transformed = [
            inputter(self._index_features(features, i), training=training)
            for i, inputter in enumerate(self.inputters)
        ]
        if self.reducer is not None:
            transformed = self.reducer(transformed)
        return transformed
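
# Construction sketch (illustrative, not part of the original module): a
# two-source ParallelInputter that concatenates the sub-inputter outputs.
# ``WordEmbedder`` is the token embedding inputter shipped with OpenNMT-tf;
# the embedding sizes and reducer choice below are arbitrary examples.
#
#   from opennmt.inputters import WordEmbedder
#   source_inputter = ParallelInputter(
#       [WordEmbedder(embedding_size=256), WordEmbedder(embedding_size=256)],
#       reducer=ConcatReducer(),
#   )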

class MixedInputter(MultiInputter):
    """A multi inputter that applies several transformations on the same data
    (e.g. combine word-level and character-level embeddings).
    """

    def __init__(self, inputters, reducer=ConcatReducer(), dropout=0.0):
        """Initializes a mixed inputter.

        Args:
          inputters: A list of :class:`opennmt.inputters.Inputter`.
          reducer: A :class:`opennmt.layers.Reducer` to merge all inputs.
          dropout: The probability to drop units in the merged inputs.
        """
        super().__init__(inputters, reducer=reducer)
        self.dropout = dropout

    def make_dataset(self, data_file, training=None):
        datasets = [
            inputter.make_dataset(data_file, training=training)
            for inputter in self.inputters
        ]
        for dataset in datasets[1:]:
            if not isinstance(dataset, datasets[0].__class__):
                raise ValueError(
                    "All inputters should use the same dataset in a MixedInputter setting"
                )
        return datasets[0]

    def get_dataset_size(self, data_file):
        for inputter in self.inputters:
            size = inputter.get_dataset_size(data_file)
            if size is not None:
                return size
        return None

    def input_signature(self):
        signature = {}
        for inputter in self.inputters:
            signature.update(inputter.input_signature())
        return signature

    def get_length(self, features, ignore_special_tokens=False):
        return self.inputters[0].get_length(
            features, ignore_special_tokens=ignore_special_tokens
        )

    def make_features(self, element=None, features=None, training=None):
        if features is None:
            features = {}
        for inputter in self.inputters:
            features = inputter.make_features(
                element=element, features=features, training=training
            )
        return features

    def build(self, input_shape):
        for inputter in self.inputters:
            inputter.build(input_shape)
        super().build(input_shape)

    def call(self, features, training=None):
        transformed = []
        for inputter in self.inputters:
            transformed.append(inputter(features, training=training))
        outputs = self.reducer(transformed)
        outputs = common.dropout(outputs, self.dropout, training=training)
        return outputs
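
# Construction sketch (illustrative, not part of the original module): mixing
# word-level and character-level embeddings of the same text. ``WordEmbedder``
# and ``CharConvEmbedder`` are inputters shipped with OpenNMT-tf; the
# hyperparameters below are arbitrary examples.
#
#   from opennmt.inputters import CharConvEmbedder, WordEmbedder
#   mixed = MixedInputter(
#       [
#           WordEmbedder(embedding_size=512),
#           CharConvEmbedder(embedding_size=32, num_outputs=128),
#       ],
#       reducer=ConcatReducer(),
#       dropout=0.3,
#   )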

class ExampleInputterAdapter:
    """Extends an inputter with methods to build evaluation and training datasets."""

    def make_evaluation_dataset(
        self,
        features_file,
        labels_file,
        batch_size,
        batch_type="examples",
        length_bucket_width=None,
        num_threads=1,
        prefetch_buffer_size=None,
    ):
        """Builds a dataset to be used for evaluation.

        Args:
          features_file: The evaluation source file.
          labels_file: The evaluation target file.
          batch_size: The batch size to use.
          batch_type: The batching strategy to use: can be "examples" or "tokens".
          length_bucket_width: The width of the length buckets to select batch
            candidates from (for efficiency). Set ``None`` to not constrain batch
            formation.
          num_threads: The number of elements processed in parallel.
          prefetch_buffer_size: The number of batches to prefetch asynchronously.
            If ``None``, use an automatically tuned value.

        Returns:
          A ``tf.data.Dataset``.

        See Also:
          :func:`opennmt.data.inference_pipeline`
        """
        if labels_file is not None:
            data_files = [features_file, labels_file]
            length_fn = [
                self.features_inputter.get_length,
                self.labels_inputter.get_length,
            ]
        else:
            data_files = features_file
            length_fn = self.get_length

        transform_fns = _get_dataset_transforms(
            self, num_threads=num_threads, training=False
        )

        dataset = self.make_dataset(data_files, training=False)
        dataset = dataset.apply(
            dataset_util.inference_pipeline(
                batch_size,
                batch_type=batch_type,
                transform_fns=transform_fns,
                length_bucket_width=length_bucket_width,
                length_fn=length_fn,
                num_threads=num_threads,
                prefetch_buffer_size=prefetch_buffer_size,
            )
        )
        return dataset

    def make_training_dataset(
        self,
        features_file,
        labels_file,
        batch_size,
        batch_type="examples",
        batch_multiplier=1,
        batch_size_multiple=1,
        shuffle_buffer_size=None,
        length_bucket_width=None,
        pad_to_bucket_boundary=False,
        maximum_features_length=None,
        maximum_labels_length=None,
        single_pass=False,
        num_shards=1,
        shard_index=0,
        num_threads=4,
        prefetch_buffer_size=None,
        cardinality_multiple=1,
        weights=None,
        batch_autotune_mode=False,
    ):
        """Builds a dataset to be used for training. It supports the full training
        pipeline, including:

        * sharding
        * shuffling
        * filtering
        * bucketing
        * prefetching

        Args:
          features_file: The source file or a list of training source files.
          labels_file: The target file or a list of training target files.
          batch_size: The batch size to use.
          batch_type: The training batching strategy to use: can be "examples" or
            "tokens".
          batch_multiplier: The batch size multiplier to prepare splitting across
            replicated graph parts.
          batch_size_multiple: When :obj:`batch_type` is "tokens", ensure that the
            resulting batch size is a multiple of this value.
          shuffle_buffer_size: The number of elements from which to sample.
          length_bucket_width: The width of the length buckets to select batch
            candidates from (for efficiency). Set ``None`` to not constrain batch
            formation.
          pad_to_bucket_boundary: Pad each batch to the length bucket boundary.
          maximum_features_length: The maximum length or list of maximum lengths of
            the features sequence(s). ``None`` to not constrain the length.
          maximum_labels_length: The maximum length of the labels sequence.
            ``None`` to not constrain the length.
          single_pass: If ``True``, makes a single pass over the training data.
          num_shards: The number of data shards (usually the number of workers in a
            distributed setting).
          shard_index: The shard index this data pipeline should read from.
          num_threads: The number of elements processed in parallel.
          prefetch_buffer_size: The number of batches to prefetch asynchronously.
            If ``None``, use an automatically tuned value.
          cardinality_multiple: Ensure that the dataset cardinality is a multiple
            of this value when :obj:`single_pass` is ``True``.
          weights: An optional list of weights to create a weighted dataset out of
            multiple training files.
          batch_autotune_mode: When enabled, all batches are padded to the maximum
            sequence length.

        Returns:
          A ``tf.data.Dataset``.

        See Also:
          :func:`opennmt.data.training_pipeline`
        """
        if labels_file is not None:
            data_files = [features_file, labels_file]
            maximum_length = [maximum_features_length, maximum_labels_length]
            features_length_fn = self.features_inputter.get_length
            labels_length_fn = self.labels_inputter.get_length
        else:
            data_files = features_file
            maximum_length = maximum_features_length
            features_length_fn = self.get_length
            labels_length_fn = None

        dataset = self.make_dataset(data_files, training=True)

        filter_fn = lambda *arg: (
            self.keep_for_training(
                misc.item_or_tuple(arg), maximum_length=maximum_length
            )
        )
        transform_fns = _get_dataset_transforms(
            self, num_threads=num_threads, training=True
        )
        transform_fns.append(lambda dataset: dataset.filter(filter_fn))

        if batch_autotune_mode:
            # In this mode we want to return batches where all sequences are padded
            # to the maximum possible length in order to maximize the memory usage.
            # Shuffling, sharding, prefetching, etc. are not applied since
            # correctness and performance are not important.

            if isinstance(dataset, list):
                # Ignore weighted dataset.
                dataset = dataset[0]

            # We repeat the dataset now to ensure full batches are always returned.
            dataset = dataset.repeat()
            for transform_fn in transform_fns:
                dataset = dataset.apply(transform_fn)

            # length_fn returns the maximum length instead of the actual example
            # length so that batches are built as if each example has the maximum
            # length.
            if labels_file is not None:
                constant_length_fn = [
                    lambda x: maximum_features_length,
                    lambda x: maximum_labels_length,
                ]
            else:
                constant_length_fn = lambda x: maximum_features_length

            # The length dimension is set to the maximum length in the padded shapes.
            padded_shapes = self.get_padded_shapes(
                dataset.element_spec, maximum_length=maximum_length
            )

            # Dynamically pad each sequence to the maximum length.
            def _pad_to_shape(tensor, padded_shape):
                if tensor.shape.rank == 0:
                    return tensor
                tensor_shape = misc.shape_list(tensor)
                paddings = [
                    [0, padded_dim - tensor_dim]
                    if tf.is_tensor(tensor_dim) and padded_dim is not None
                    else [0, 0]
                    for tensor_dim, padded_dim in zip(tensor_shape, padded_shape)
                ]
                return tf.pad(tensor, paddings)

            dataset = dataset.map(
                lambda *arg: tf.nest.map_structure(
                    _pad_to_shape, misc.item_or_tuple(arg), padded_shapes
                )
            )
            dataset = dataset.apply(
                dataset_util.batch_sequence_dataset(
                    batch_size,
                    batch_type=batch_type,
                    batch_multiplier=batch_multiplier,
                    length_bucket_width=1,
                    length_fn=constant_length_fn,
                )
            )
            return dataset

        if weights is not None:
            dataset = (dataset, weights)
        dataset = dataset_util.training_pipeline(
            batch_size,
            batch_type=batch_type,
            batch_multiplier=batch_multiplier,
            batch_size_multiple=batch_size_multiple,
            transform_fns=transform_fns,
            length_bucket_width=length_bucket_width,
            pad_to_bucket_boundary=pad_to_bucket_boundary,
            features_length_fn=features_length_fn,
            labels_length_fn=labels_length_fn,
            maximum_features_length=maximum_features_length,
            maximum_labels_length=maximum_labels_length,
            single_pass=single_pass,
            num_shards=num_shards,
            shard_index=shard_index,
            num_threads=num_threads,
            dataset_size=self.get_dataset_size(data_files),
            shuffle_buffer_size=shuffle_buffer_size,
            prefetch_buffer_size=prefetch_buffer_size,
            cardinality_multiple=cardinality_multiple,
        )(dataset)
        return dataset
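
    # Illustrative usage sketch (assumes ``inputter`` is an ExampleInputter or
    # another ExampleInputterAdapter whose sub-inputters were initialized; file
    # names and sizes below are hypothetical):
    #
    #   dataset = inputter.make_training_dataset(
    #       "train.src",
    #       "train.tgt",
    #       batch_size=4096,
    #       batch_type="tokens",
    #       shuffle_buffer_size=500000,
    #       maximum_features_length=100,
    #       maximum_labels_length=100,
    #   )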

def _register_example_weight(features, labels, weight):
    labels["weight"] = tf.strings.to_number(weight)
    return features, labels

class ExampleInputter(ParallelInputter, ExampleInputterAdapter):
    """An inputter that returns training examples (parallel features and labels)."""

    def __init__(
        self,
        features_inputter,
        labels_inputter,
        share_parameters=False,
        accepted_annotations=None,
    ):
        """Initializes this inputter.

        Args:
          features_inputter: An inputter producing the features (source).
          labels_inputter: An inputter producing the labels (target).
          share_parameters: Share the inputters parameters.
          accepted_annotations: An optional dictionary mapping annotation names in
            the data configuration (e.g. "train_alignments") to a callable with
            signature ``(features, labels, annotations) -> (features, labels)``.
        """
        self.features_inputter = features_inputter
        self.labels_inputter = labels_inputter
        super().__init__(
            [self.features_inputter, self.labels_inputter],
            share_parameters=share_parameters,
            combine_features=False,
        )

        # Set a meaningful prefix for source and target.
        self.features_inputter.asset_prefix = "source_"
        self.labels_inputter.asset_prefix = "target_"

        self.accepted_annotations = accepted_annotations or {}
        self.accepted_annotations["example_weights"] = _register_example_weight
        self.annotation_files = {}

    def initialize(self, data_config):
        super().initialize(data_config)

        # Check if some accepted annotations are defined in the data configuration.
        for annotation in self.accepted_annotations.keys():
            path = data_config.get(annotation)
            if path is not None:
                self.annotation_files[annotation] = path

    def make_dataset(self, data_file, training=None):
        dataset = super().make_dataset(data_file, training=training)
        if not training or not self.annotation_files:
            return dataset

        # Some annotations are configured and should be zipped to the training dataset.
        all_annotation_datasets = tf.nest.map_structure(
            tf.data.TextLineDataset, self.annotation_files
        )

        # Common case of a non-weighted dataset.
        if not isinstance(dataset, list):
            return tf.data.Dataset.zip(
                {"examples": dataset, **all_annotation_datasets}
            )

        # Otherwise, there should be as many annotations datasets as input datasets.
        datasets = dataset
        for name, annotation_datasets in all_annotation_datasets.items():
            num_annotation_datasets = (
                len(annotation_datasets) if isinstance(annotation_datasets, list) else 1
            )
            if num_annotation_datasets != len(datasets):
                raise ValueError(
                    "%d '%s' files were provided, but %d were expected to match the "
                    "number of data files"
                    % (num_annotation_datasets, name, len(datasets))
                )

        # Convert dict of lists to list of dicts.
        all_annotation_datasets = [
            dict(zip(all_annotation_datasets, t))
            for t in zip(*all_annotation_datasets.values())
        ]

        return [
            tf.data.Dataset.zip({"examples": dataset, **annotation_datasets})
            for dataset, annotation_datasets in zip(datasets, all_annotation_datasets)
        ]

    def get_dataset_size(self, data_file):
        size = super().get_dataset_size(data_file)
        if size is not None:
            for annotation, path in self.annotation_files.items():
                annotation_size = tf.nest.map_structure(misc.count_lines, path)
                if size != annotation_size:
                    raise RuntimeError(
                        "Annotation dataset '%s' does not have the same size as "
                        "the examples dataset" % annotation
                    )
        return size

    def make_features(self, element=None, features=None, training=None):
        if training and self.annotation_files:
            annotations = element.copy()
            example = annotations.pop("examples")
        else:
            annotations = {}
            example = element

        features, labels = super().make_features(
            element=example, features=features, training=training
        )

        # Load each annotation into the features and labels dict.
        for name, annotation in annotations.items():
            features, labels = self.accepted_annotations[name](
                features, labels, annotation
            )
        return features, labels

    def make_inference_dataset(
        self,
        features_file,
        batch_size,
        batch_type="examples",
        length_bucket_width=None,
        num_threads=1,
        prefetch_buffer_size=None,
    ):
        return self.features_inputter.make_inference_dataset(
            features_file,
            batch_size,
            batch_type=batch_type,
            length_bucket_width=length_bucket_width,
            num_threads=num_threads,
            prefetch_buffer_size=prefetch_buffer_size,
        )
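
# End-to-end sketch (illustrative, not part of the original module): pairing a
# source and a target ``WordEmbedder`` and initializing them from a user data
# configuration. File paths below are hypothetical.
#
#   from opennmt.inputters import WordEmbedder
#   example_inputter = ExampleInputter(
#       WordEmbedder(embedding_size=512),
#       WordEmbedder(embedding_size=512),
#   )
#   example_inputter.initialize(
#       {
#           "source_vocabulary": "src-vocab.txt",
#           "target_vocabulary": "tgt-vocab.txt",
#       }
#   )
#   eval_dataset = example_inputter.make_evaluation_dataset(
#       "valid.src", "valid.tgt", batch_size=32
#   )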

def _get_dataset_transforms(
    inputter,
    num_threads=None,
    training=None,
    prepare_batch_size=128,
):
    transform_fns = []

    if inputter.has_prepare_step():
        prepare_fn = lambda *arg: inputter.prepare_elements(
            misc.item_or_tuple(arg), training=training
        )
        transform_fns.extend(
            [
                lambda dataset: dataset.batch(prepare_batch_size),
                lambda dataset: dataset.map(prepare_fn, num_parallel_calls=num_threads),
                lambda dataset: dataset.unbatch(),
            ]
        )

    map_fn = lambda *arg: inputter.make_features(
        element=misc.item_or_tuple(arg), training=training
    )
    transform_fns.append(
        lambda dataset: dataset.map(map_fn, num_parallel_calls=num_threads)
    )
    return transform_fns