Add a Decoder
1. Add a new decoder class
Source code for decoders lives under `ludwig/decoders/`.
Decoders are grouped into modules by their output feature type. For instance, all new sequence decoders should be added
to `ludwig/decoders/sequence_decoders.py`.
Note
A decoder may support multiple output types. If so, it should be defined in the module corresponding to its most
generic supported type. If a decoder is generic with respect to output type, add it to
`ludwig/decoders/generic_decoders.py`.
To create a new decoder:
- Define a new decoder class. Inherit from `ludwig.decoders.base.Decoder` or one of its subclasses.
- Create all layers and state in the `__init__` method, after calling `super().__init__()`.
- Implement your decoder's forward pass in `def forward(self, combiner_outputs, **kwargs):`.
- Define a schema class.
Note
`Decoder` inherits from `LudwigModule`, which is itself a `torch.nn.Module`, so all the usual concerns of developing
Torch modules apply.
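All decoder parameters should be provided as keyword arguments to the constructor and should have a default value.
For example, the `SequenceGeneratorDecoder` decoder takes the following parameters in its constructor. Here
`vocab_size` and `max_sequence_length` have no default because their values come from the output feature's metadata
rather than from the user's config: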
```python
from ludwig.constants import SEQUENCE, TEXT
from ludwig.decoders.base import Decoder
from ludwig.decoders.registry import register_decoder


@register_decoder("generator", [SEQUENCE, TEXT])
class SequenceGeneratorDecoder(Decoder):
    def __init__(
        self,
        vocab_size: int,
        max_sequence_length: int,
        cell_type: str = "gru",
        input_size: int = 256,
        reduce_input: str = "sum",
        num_layers: int = 1,
        **kwargs,
    ):
        super().__init__()
        # Initialize any modules, layers, or variable state
```
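One practical benefit of keyword parameters with defaults is that a decoder can be built directly from a partial config dictionary. The plain-Python sketch below illustrates that idea; `DecoderBase`, `ToyGeneratorDecoder`, and the `config` dictionary are hypothetical stand-ins, not Ludwig classes (the real `Decoder` is a `torch.nn.Module`).

```python
class DecoderBase:
    """Hypothetical stand-in for ludwig.decoders.base.Decoder (the real one is a torch.nn.Module)."""

    def __init__(self):
        self.initialized = True


class ToyGeneratorDecoder(DecoderBase):
    def __init__(
        self,
        vocab_size: int,
        max_sequence_length: int,
        cell_type: str = "gru",
        input_size: int = 256,
        reduce_input: str = "sum",
        num_layers: int = 1,
        **kwargs,
    ):
        # Call the parent constructor first, then create layers and state.
        super().__init__()
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.cell_type = cell_type
        self.input_size = input_size
        self.reduce_input = reduce_input
        self.num_layers = num_layers
        # Extra keys land in **kwargs instead of raising a TypeError.
        self.extra_params = kwargs


# A partial config: every omitted key falls back to the constructor default.
config = {"vocab_size": 100, "max_sequence_length": 20, "num_layers": 2}
decoder = ToyGeneratorDecoder(**config)
print(decoder.cell_type, decoder.num_layers)  # gru 2
```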
2. Implement forward
The actual computation of activations takes place inside the decoder's `forward` method.
All decoders should implement it with the following signature:
```python
from ludwig.constants import HIDDEN, LOGITS


def forward(self, combiner_outputs, **kwargs):
    # Perform the forward pass, for example:
    # combiner_hidden_output = combiner_outputs[HIDDEN]
    # ...
    # logits = result of decoder forward pass
    return {LOGITS: logits}
```
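To make the contract concrete, here is a dependency-free sketch of a decoder whose `forward` reads the hidden tensor from `combiner_outputs` and returns a dictionary keyed by `LOGITS`. Nested Python lists stand in for tensors, `ToyLinearDecoder` is hypothetical, and the string values of `HIDDEN` and `LOGITS` are assumptions; a real decoder would subclass `Decoder` and use a layer such as `torch.nn.Linear`.

```python
from typing import Dict, List

# Stand-ins for the keys defined in ludwig.constants (string values assumed).
HIDDEN = "hidden"
LOGITS = "logits"

Matrix = List[List[float]]


class ToyLinearDecoder:
    """Hypothetical decoder: projects a b x h hidden tensor to b x num_classes logits."""

    def __init__(self, weights: Matrix, bias: List[float]):
        self.weights = weights  # shape h x num_classes
        self.bias = bias  # shape num_classes

    def forward(self, combiner_outputs: Dict[str, Matrix], **kwargs) -> Dict[str, Matrix]:
        hidden = combiner_outputs[HIDDEN]  # shape b x h
        # logits[b][j] = sum_i hidden[b][i] * weights[i][j] + bias[j]
        logits = [
            [
                sum(row[i] * self.weights[i][j] for i in range(len(row))) + self.bias[j]
                for j in range(len(self.bias))
            ]
            for row in hidden
        ]
        return {LOGITS: logits}


decoder = ToyLinearDecoder(weights=[[1.0, 0.0], [0.0, 2.0]], bias=[0.5, -0.5])
out = decoder.forward({HIDDEN: [[1.0, 1.0], [3.0, 0.0]]})
print(out[LOGITS])  # [[1.5, 1.5], [3.5, -0.5]]
```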
Inputs
- `combiner_outputs` (`Dict[str, torch.Tensor]`): A dictionary holding the output of a combiner, or the combination
of the combiner output and the activations of any dependent output decoders. The dictionary of combiner outputs includes
either a tensor of shape `b x h`, where `b` is the batch size and `h` is the embedding size, or a sequence of
embeddings of shape `b x s x h`, where `s` is the sequence length.
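Because the combiner output may arrive as `b x h` or as `b x s x h`, a decoder that needs one vector per example has to collapse the sequence dimension, which is what a parameter like `reduce_input="sum"` suggests. The sketch below uses nested lists in place of tensors; `reduce_sequence` is a hypothetical helper, not a Ludwig function, and only `sum` and `mean` are shown.

```python
from typing import List, Union

Matrix = List[List[float]]  # b x h
Tensor3 = List[List[List[float]]]  # b x s x h


def reduce_sequence(x: Union[Matrix, Tensor3], mode: str = "sum") -> Matrix:
    """Collapse a b x s x h input to b x h; pass a b x h input through unchanged."""
    # Detect rank by checking whether the first element of the first row is itself a list.
    if not x or not x[0] or not isinstance(x[0][0], list):
        return x  # already b x h
    reduced = []
    for sequence in x:  # sequence has shape s x h
        summed = [sum(step[k] for step in sequence) for k in range(len(sequence[0]))]
        if mode == "sum":
            reduced.append(summed)
        elif mode == "mean":
            reduced.append([v / len(sequence) for v in summed])
        else:
            raise ValueError(f"unsupported reduce mode: {mode}")
    return reduced


batch_seq = [[[1.0, 2.0], [3.0, 4.0]]]  # b=1, s=2, h=2
print(reduce_sequence(batch_seq, "sum"))  # [[4.0, 6.0]]
print(reduce_sequence(batch_seq, "mean"))  # [[2.0, 3.0]]
print(reduce_sequence([[1.0, 2.0]]))  # [[1.0, 2.0]]
```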
Return
- (`Dict[str, torch.Tensor]`): A dictionary of decoder output tensors. Typical decoders return values for the keys
`LOGITS`, `PREDICTION`, or both (defined in `ludwig.constants`).
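A decoder that returns both keys typically derives predictions from its logits. The sketch below shows one common choice, taking the argmax over classes; `logits_to_outputs` is hypothetical, nested lists stand in for tensors, and the string values of `LOGITS` and `PREDICTION` are assumptions.

```python
from typing import Dict, List

# Stand-ins for the keys defined in ludwig.constants (string values assumed).
LOGITS = "logits"
PREDICTION = "prediction"


def logits_to_outputs(logits: List[List[float]]) -> Dict[str, list]:
    """Return the raw logits plus per-example class predictions (argmax over classes)."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return {LOGITS: logits, PREDICTION: predictions}


outputs = logits_to_outputs([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
print(outputs[PREDICTION])  # [1, 0]
```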
3. Add the new decoder class to the corresponding decoder registry
The mapping between decoder names in the model definition and decoder classes is established by registering each class
in a decoder registry. The decoder registry is defined in `ludwig/decoders/registry.py`. To register your class,
add the `@register_decoder` decorator on the line above its class definition, specifying the name of the decoder and a
list of supported output feature types:
```python
@register_decoder("generator", [SEQUENCE, TEXT])
class SequenceGeneratorDecoder(Decoder):
    ...
```
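To see why a decorator is enough to make a class discoverable by name, here is a simplified, self-contained registry that mirrors the usage above. It is a hypothetical illustration of the pattern, not the code in `ludwig/decoders/registry.py`; `DECODER_REGISTRY` and `get_decoder_cls` are invented names, and the strings `"sequence"` and `"text"` stand in for the `SEQUENCE` and `TEXT` constants.

```python
from typing import Callable, Dict, List

# Hypothetical registry: output feature type -> {decoder name -> decoder class}.
DECODER_REGISTRY: Dict[str, Dict[str, type]] = {}


def register_decoder(name: str, features: List[str]) -> Callable[[type], type]:
    """Register the decorated class under `name` for every listed output feature type."""

    def wrap(cls: type) -> type:
        for feature in features:
            DECODER_REGISTRY.setdefault(feature, {})[name] = cls
        return cls  # return the class unchanged so it can still be used normally

    return wrap


def get_decoder_cls(feature: str, name: str) -> type:
    return DECODER_REGISTRY[feature][name]


@register_decoder("generator", ["sequence", "text"])
class SequenceGeneratorDecoder:
    pass


print(get_decoder_cls("text", "generator").__name__)  # SequenceGeneratorDecoder
```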
4. Define a schema class
To ensure that user config validation works as intended for your custom decoder, define a schema class to go along
with it. To do this, apply the `marshmallow_dataclass` `dataclass` decorator to a class whose attributes are all the
inputs to your custom decoder. For each attribute, use a utility function from the `ludwig.schema.utils` module to
validate that input. Finally, register the schema class with `@register_decoder_config` and add a reference to it on the
custom decoder class. For example:
```python
from marshmallow_dataclass import dataclass

import ludwig.schema.utils as schema_utils
from ludwig.constants import SEQUENCE, TEXT
from ludwig.schema.decoders.base import BaseDecoderConfig
from ludwig.schema.decoders.utils import register_decoder_config


@register_decoder_config("generator", [SEQUENCE, TEXT])
@dataclass
class SequenceGeneratorDecoderConfig(BaseDecoderConfig):
    type: str = schema_utils.StringOptions(options=["generator"], default="generator")
    vocab_size: int = schema_utils.Integer(default=None, description="")
    max_sequence_length: int = schema_utils.Integer(default=None, description="")
    cell_type: str = schema_utils.String(default="gru", description="")
    input_size: int = schema_utils.Integer(default=256, description="")
    reduce_input: str = schema_utils.ReductionOptions(default="sum")
    num_layers: int = schema_utils.Integer(default=1, description="")
```
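Roughly, the schema class gives every parameter a default and rejects values of the wrong kind. The sketch below imitates that behavior with the standard-library `dataclasses` module instead of `marshmallow_dataclass` and `ludwig.schema.utils`, so it runs without Ludwig installed. `ToyGeneratorDecoderConfig` is hypothetical, and the checks in `__post_init__` are simplified assumptions, not the exact validation Ludwig performs.

```python
from dataclasses import dataclass
from typing import Optional

INTEGER_FIELDS = ("vocab_size", "max_sequence_length", "input_size", "num_layers")


@dataclass
class ToyGeneratorDecoderConfig:
    """Hypothetical stdlib stand-in for a marshmallow_dataclass schema class."""

    type: str = "generator"
    vocab_size: Optional[int] = None
    max_sequence_length: Optional[int] = None
    cell_type: str = "gru"
    input_size: int = 256
    reduce_input: str = "sum"
    num_layers: int = 1

    def __post_init__(self):
        # Roughly what an options validator enforces: only listed values are accepted.
        if self.type not in ("generator",):
            raise ValueError(f"type must be one of ['generator'], got {self.type!r}")
        # Roughly what an integer validator enforces (None allowed, bool rejected).
        for name in INTEGER_FIELDS:
            value = getattr(self, name)
            if value is not None and (isinstance(value, bool) or not isinstance(value, int)):
                raise ValueError(f"{name} must be an integer, got {value!r}")


cfg = ToyGeneratorDecoderConfig(vocab_size=100, num_layers=2)
print(cfg.cell_type, cfg.num_layers)  # gru 2

try:
    ToyGeneratorDecoderConfig(num_layers="two")
except ValueError as err:
    print(err)  # num_layers must be an integer, got 'two'
```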
Lastly, add a reference to the schema class on the custom decoder:
```python
class SequenceGeneratorDecoder(Decoder):
    # ... __init__ and forward as above ...

    @staticmethod
    def get_schema_cls():
        return SequenceGeneratorDecoderConfig
```
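The reference matters because code that only holds the decoder class can use it to find the matching schema, fill in defaults, and validate user input before construction. The sketch below is a hypothetical illustration of that lookup, with a stdlib dataclass standing in for the schema and an invented `build_decoder` helper; Ludwig's actual config-loading path is more involved.

```python
from dataclasses import asdict, dataclass


@dataclass
class ToyDecoderConfig:
    """Hypothetical schema class with a default for every field."""

    cell_type: str = "gru"
    num_layers: int = 1


class ToyDecoder:
    def __init__(self, cell_type: str = "gru", num_layers: int = 1, **kwargs):
        self.cell_type = cell_type
        self.num_layers = num_layers

    @staticmethod
    def get_schema_cls():
        # Link the decoder to its schema so config code can find it from the class alone.
        return ToyDecoderConfig


def build_decoder(decoder_cls, user_config: dict):
    """Hypothetical helper: fill defaults via the schema class, then build the decoder."""
    config = decoder_cls.get_schema_cls()(**user_config)  # unknown keys raise TypeError here
    return decoder_cls(**asdict(config))


decoder = build_decoder(ToyDecoder, {"num_layers": 3})
print(decoder.cell_type, decoder.num_layers)  # gru 3
```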