Skip to content

GAN

Bases: BaseModel

Base class of GAN synthesizer models. The main methods are fit (for fitting the synthesizer), save/load, and sample (to obtain synthetic records).

Parameters:

Name Type Description Default
model_parameters ModelParameters

Set of architectural parameters for model definition.

required
Source code in /opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/ydata_synthetic/synthesizers/base.py
@typechecked
class BaseGANModel(BaseModel):
    """
    Base class of GAN synthesizer models.
    The main methods are fit (for fitting the synthesizer), save/load and sample (obtain synthetic records).
    Args:
        model_parameters (ModelParameters):
            Set of architectural parameters for model definition.
    """
    def __init__(
            self,
            model_parameters: ModelParameters
    ):
        self._configure_gpu_memory_growth()
        #Validate the provided model parameters
        if model_parameters.betas is not None:
            assert len(model_parameters.betas) == 2, "Please provide the betas information as a tuple."

        # Stored so that the ``model_parameters`` property can return it.
        self._model_parameters = model_parameters
        self.batch_size = model_parameters.batch_size
        self._set_lr(model_parameters.lr)
        # ``betas`` may be None (the validation above allows it); indexing None would raise TypeError.
        if model_parameters.betas is not None:
            self.beta_1 = model_parameters.betas[0]
            self.beta_2 = model_parameters.betas[1]
        else:
            self.beta_1 = None
            self.beta_2 = None
        self.noise_dim = model_parameters.noise_dim
        self.data_dim = None
        self.layers_dim = model_parameters.layers_dim

        # Additional parameters for the CTGAN
        self.generator_dims = model_parameters.generator_dims
        self.critic_dims = model_parameters.critic_dims
        self.l2_scale = model_parameters.l2_scale
        self.latent_dim = model_parameters.latent_dim
        self.gp_lambda = model_parameters.gp_lambda
        self.pac = model_parameters.pac

        # Set by ``fit`` once the matching data processor has been fitted.
        self.processor=None
        if self.__MODEL__ in RegularModels.__members__ or \
            self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL:
            self.tau = model_parameters.tau_gs

    @staticmethod
    def _configure_gpu_memory_growth():
        """Best-effort: enable memory growth on every visible GPU.

        TensorFlow requires the memory-growth setting to be identical across
        GPUs, so it is applied to all of them rather than only the first one.
        """
        gpu_devices = tfconfig.list_physical_devices('GPU')
        try:
            for gpu_device in gpu_devices:
                tfconfig.experimental.set_memory_growth(gpu_device, True)
        except (ValueError, RuntimeError):
            # Invalid device or cannot modify virtual devices once initialized.
            pass

    # pylint: disable=E1101
    def __call__(self, inputs, **kwargs):
        return self.model(inputs=inputs, **kwargs)

    # pylint: disable=C0103
    def _set_lr(self, lr):
        """Set generator (``g_lr``) and discriminator (``d_lr``) learning rates.

        Args:
            lr: a single number used for both networks, or a 2-element
                list/tuple ``(generator_lr, discriminator_lr)``.
        """
        if isinstance(lr, (int, float)):
            # Integers are accepted too and normalized to float.
            self.g_lr=float(lr)
            self.d_lr=float(lr)
        elif isinstance(lr,(list, tuple)):
            assert len(lr)==2, "Please provide a two values array for the learning rates or a float."
            self.g_lr=lr[0]
            self.d_lr=lr[1]

    def define_gan(self):
        """Define the trainable model components.

        Optionally validate model structure with mock inputs and initialize optimizers."""
        raise NotImplementedError

    @property
    def model_parameters(self):
        "Returns the parameters of the model."
        return self._model_parameters

    @property
    def model_name(self):
        "Returns the model (class) name."
        return self.__class__.__name__

    def fit(self,
              data: Union[DataFrame, array],
              num_cols: Optional[List[str]] = None,
              cat_cols: Optional[List[str]] = None,
              train_arguments: Optional[TrainParameters] = None) -> Union[DataFrame, array]:
        """
        Trains and fits a synthesizer model to a given input dataset.

        This base implementation selects the data processor matching
        ``self.__MODEL__`` and fits it to ``data``, storing it in ``self.processor``.

        Args:
            data (Union[DataFrame, array]): Training data
            num_cols (Optional[List[str]]): List with the names of the numerical columns
            cat_cols (Optional[List[str]]): List of names of categorical columns
            train_arguments (Optional[TrainParameters]): Training parameters. Required for
                CTGAN (``n_clusters``, ``epsilon``) and DoppelGANger
                (``measurement_cols``, ``sequence_length``).

        Returns:
            None in this base implementation (it only fits ``self.processor``).

        Raises:
            ValueError: if ``train_arguments`` is missing for a model that needs it.
        """
        if self.__MODEL__ in RegularModels.__members__:
            self.processor = RegularDataProcessor(num_cols=num_cols, cat_cols=cat_cols).fit(data)
        elif self.__MODEL__ in TimeSeriesModels.__members__:
            self.processor = TimeSeriesDataProcessor(num_cols=num_cols, cat_cols=cat_cols).fit(data)
        elif self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL:
            if train_arguments is None:
                raise ValueError(f'train_arguments (n_clusters, epsilon) are required to fit {self.__MODEL__}.')
            n_clusters = train_arguments.n_clusters
            epsilon = train_arguments.epsilon
            self.processor = CTGANDataProcessor(n_clusters=n_clusters, epsilon=epsilon,
                                                num_cols=num_cols, cat_cols=cat_cols).fit(data)
        elif self.__MODEL__ == DoppelGANgerProcessor.SUPPORTED_MODEL:
            if train_arguments is None:
                raise ValueError(f'train_arguments (measurement_cols, sequence_length) are required to fit {self.__MODEL__}.')
            measurement_cols = train_arguments.measurement_cols
            sequence_length = train_arguments.sequence_length
            self.processor = DoppelGANgerProcessor(num_cols=num_cols, cat_cols=cat_cols,
                                                   measurement_cols=measurement_cols,
                                                   sequence_length=sequence_length).fit(data)
        else:
            print(f'A DataProcessor is not available for the {self.__MODEL__}.')

    def sample(self, n_samples: int):
        """
        Generates samples from the trained synthesizer.

        Args:
            n_samples (int): Number of rows to generate. Must be positive.

        Returns:
            synth_sample (pandas.DataFrame): generated synthetic samples. The raw generator
                output is truncated to ``n_samples`` rows before ``inverse_transform``.

        Raises:
            ValueError: if ``n_samples`` is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError(f'n_samples must be a positive integer, got {n_samples}.')
        # Ceiling division: just enough batches to cover n_samples (no surplus batch).
        steps = (n_samples + self.batch_size - 1) // self.batch_size
        data = []
        for _ in tqdm.trange(steps, desc='Synthetic data generation'):
            z = random.uniform([self.batch_size, self.noise_dim], dtype=tf.dtypes.float32)
            records = self.generator(z, training=False).numpy()
            data.append(records)
        # The last batch may overshoot; keep exactly n_samples rows.
        return self.processor.inverse_transform(array(vstack(data)[:n_samples]))

    def save(self, path):
        """
        Saves a synthesizer as a pickle.

        For the Wasserstein variants the critic is left out of the pickle, but it
        is re-attached afterwards so the in-memory model stays usable.

        Args:
            path (str): Path to write the synthesizer as a pickle object.
        """
        # Only the generator is needed for sampling, so the critic is detached while dumping.
        detach_critic = self.__MODEL__ in ('WGAN', 'WGAN_GP', 'CWGAN_GP') and hasattr(self, 'critic')
        critic = None
        if detach_critic:
            critic = self.critic
            del self.critic
        make_keras_picklable()
        try:
            dump(self, path)
        finally:
            if detach_critic:
                self.critic = critic

    @classmethod
    def load(cls, path):
        """
        Loads a saved synthesizer from a pickle.

        Unpickling can execute arbitrary code: only load files from trusted sources.

        Args:
            path (str): Path to read the synthesizer pickle from.
        """
        cls._configure_gpu_memory_growth()
        synth = load(path)
        return synth

model_name property

Returns the model (class) name.

model_parameters property

Returns the parameters of the model.

define_gan()

Define the trainable model components.

Optionally validate model structure with mock inputs and initialize optimizers.

Source code in /opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/ydata_synthetic/synthesizers/base.py
def define_gan(self):
    """Abstract hook that subclasses implement to build their networks.

    Implementations may also check the model structure with mock inputs and
    set up the optimizers. The base class provides no implementation.
    """
    # Subclasses must override this; the base version always refuses.
    raise NotImplementedError

fit(data, num_cols=None, cat_cols=None, train_arguments=None)

Trains and fits a synthesizer model to a given input dataset.

Parameters:

Name Type Description Default
data Union[DataFrame, array]

Training data

required
num_cols Optional[List[str]]

List with the names of the numerical columns

None
cat_cols Optional[List[str]]

List of names of categorical columns

None
train_arguments Optional[TrainParameters]

Training parameters

None

Returns:

Type Description
Union[DataFrame, array]

Fitted synthesizer

Source code in /opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/ydata_synthetic/synthesizers/base.py
def fit(self,
          data: Union[DataFrame, array],
          num_cols: Optional[List[str]] = None,
          cat_cols: Optional[List[str]] = None,
          train_arguments: Optional[TrainParameters] = None) -> Union[DataFrame, array]:
    """
    Trains and fits a synthesizer model to a given input dataset.

    This base implementation selects the data processor matching
    ``self.__MODEL__`` and fits it to ``data``, storing it in ``self.processor``.

    Args:
        data (Union[DataFrame, array]): Training data
        num_cols (Optional[List[str]]): List with the names of the numerical columns
        cat_cols (Optional[List[str]]): List of names of categorical columns
        train_arguments (Optional[TrainParameters]): Training parameters. Required for
            CTGAN (``n_clusters``, ``epsilon``) and DoppelGANger
            (``measurement_cols``, ``sequence_length``).

    Returns:
        None in this base implementation (it only fits ``self.processor``).

    Raises:
        ValueError: if ``train_arguments`` is missing for a model that needs it.
    """
    if self.__MODEL__ in RegularModels.__members__:
        self.processor = RegularDataProcessor(num_cols=num_cols, cat_cols=cat_cols).fit(data)
    elif self.__MODEL__ in TimeSeriesModels.__members__:
        self.processor = TimeSeriesDataProcessor(num_cols=num_cols, cat_cols=cat_cols).fit(data)
    elif self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL:
        if train_arguments is None:
            raise ValueError(f'train_arguments (n_clusters, epsilon) are required to fit {self.__MODEL__}.')
        n_clusters = train_arguments.n_clusters
        epsilon = train_arguments.epsilon
        self.processor = CTGANDataProcessor(n_clusters=n_clusters, epsilon=epsilon,
                                            num_cols=num_cols, cat_cols=cat_cols).fit(data)
    elif self.__MODEL__ == DoppelGANgerProcessor.SUPPORTED_MODEL:
        if train_arguments is None:
            raise ValueError(f'train_arguments (measurement_cols, sequence_length) are required to fit {self.__MODEL__}.')
        measurement_cols = train_arguments.measurement_cols
        sequence_length = train_arguments.sequence_length
        self.processor = DoppelGANgerProcessor(num_cols=num_cols, cat_cols=cat_cols,
                                               measurement_cols=measurement_cols,
                                               sequence_length=sequence_length).fit(data)
    else:
        print(f'A DataProcessor is not available for the {self.__MODEL__}.')

load(path) classmethod

Loads a saved synthesizer from a pickle.

Parameters:

Name Type Description Default
path str

Path to read the synthesizer pickle from.

required
Source code in /opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/ydata_synthetic/synthesizers/base.py
@classmethod
def load(cls, path):
    """
    Loads a saved synthesizer from a pickle.

    Unpickling can execute arbitrary code: only load files from trusted sources.

    Args:
        path (str): Path to read the synthesizer pickle from.
    """
    gpu_devices = tfconfig.list_physical_devices('GPU')
    try:
        # TensorFlow requires the memory-growth setting to match across GPUs,
        # so apply it to every visible device, not just the first one.
        for gpu_device in gpu_devices:
            tfconfig.experimental.set_memory_growth(gpu_device, True)
    except (ValueError, RuntimeError):
        # Invalid device or cannot modify virtual devices once initialized.
        pass
    synth = load(path)
    return synth

sample(n_samples)

Generates samples from the trained synthesizer.

Parameters:

Name Type Description Default
n_samples int

Number of rows to generate.

required

Returns:

Name Type Description
synth_sample pandas.DataFrame

generated synthetic samples.

Source code in /opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/ydata_synthetic/synthesizers/base.py
def sample(self, n_samples: int):
    """
    Generates samples from the trained synthesizer.

    Args:
        n_samples (int): Number of rows to generate. Must be positive.

    Returns:
        synth_sample (pandas.DataFrame): generated synthetic samples. The raw generator
            output is truncated to ``n_samples`` rows before ``inverse_transform``.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f'n_samples must be a positive integer, got {n_samples}.')
    # Ceiling division: just enough batches to cover n_samples (no surplus batch).
    steps = (n_samples + self.batch_size - 1) // self.batch_size
    data = []
    for _ in tqdm.trange(steps, desc='Synthetic data generation'):
        z = random.uniform([self.batch_size, self.noise_dim], dtype=tf.dtypes.float32)
        records = self.generator(z, training=False).numpy()
        data.append(records)
    # The last batch may overshoot; keep exactly n_samples rows.
    return self.processor.inverse_transform(array(vstack(data)[:n_samples]))

save(path)

Saves a synthesizer as a pickle.

Parameters:

Name Type Description Default
path str

Path to write the synthesizer as a pickle object.

required
Source code in /opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/ydata_synthetic/synthesizers/base.py
def save(self, path):
    """
    Saves a synthesizer as a pickle.

    For the Wasserstein variants the critic is left out of the pickle, but it
    is re-attached afterwards so the in-memory model stays usable (e.g. for a
    second save or further training).

    Args:
        path (str): Path to write the synthesizer as a pickle object.
    """
    # Only the generator is needed for sampling, so the critic is detached while dumping.
    detach_critic = self.__MODEL__ in ('WGAN', 'WGAN_GP', 'CWGAN_GP') and hasattr(self, 'critic')
    critic = None
    if detach_critic:
        critic = self.critic
        del self.critic
    make_keras_picklable()
    try:
        dump(self, path)
    finally:
        if detach_critic:
            self.critic = critic