Skip to content

WGAN_GP

Bases: BaseGANModel

Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
class WGAN_GP(BaseGANModel):
    """Wasserstein GAN with gradient penalty (WGAN-GP) synthesizer for regular data.

    Each training step runs ``n_critic`` critic updates followed by
    ``n_generator`` generator updates. The critic loss is
    ``mean(critic(fake)) - mean(critic(real))`` plus ``gradient_penalty_weight``
    times the gradient penalty (https://arxiv.org/abs/1704.00028).
    """

    __MODEL__='WGAN_GP'

    def __init__(self, model_parameters, n_generator:int=1, n_critic:int=1, gradient_penalty_weight:int=10):
        """Store the WGAN-GP update schedule and penalty weight, then initialise the base model.

        Args:
            model_parameters: passed unchanged to ``BaseGANModel.__init__``.
            n_generator (int): generator updates per training step. Defaults to 1.
            n_critic (int): critic updates per training step. Defaults to 1.
            gradient_penalty_weight (int): multiplier of the gradient penalty in
                the critic loss. Defaults to 10.
        """
        # As recommended in WGAN paper - https://arxiv.org/abs/1701.07875
        # WGAN-GP - WGAN with Gradient Penalty
        self.n_critic = n_critic
        self.n_generator = n_generator
        self.gradient_penalty_weight = gradient_penalty_weight
        super().__init__(model_parameters)

    def define_gan(self, activation_info: Optional[NamedTuple] = None):
        """Define the trainable model components.

        Builds ``self.generator`` (noise -> data rows) and ``self.critic``
        (data rows -> score) and creates one Adam optimizer for each.

        Args:
            activation_info (Optional[NamedTuple], optional): forwarded to the
                generator's ``build_model``. Defaults to None.

        Returns:
            (generator_optimizer, critic_optimizer): Generator and critic optimizers.
        """
        self.generator = Generator(self.batch_size). \
            build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
                        activation_info=activation_info, tau = self.tau)

        self.critic = Critic(self.batch_size). \
            build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

        g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        return g_optimizer, c_optimizer

    def gradient_penalty(self, real, fake):
        """Compute the WGAN-GP gradient penalty.

        Interpolates randomly between paired real and fake rows and penalises
        the critic when the L2 norm of its input gradient at those points
        differs from 1.

        Args:
            real: batch of real rows; assumed rank 2, ``(batch, data_dim)``.
            fake: batch of generated rows with the same shape as ``real``.

        Returns:
            Scalar tensor: batch mean of ``(||d critic(x_hat) / d x_hat||_2 - 1) ** 2``.
        """
        # One interpolation coefficient per row, broadcast across the features.
        epsilon = tf.random.uniform([tf.shape(real)[0], 1], minval=0.0, maxval=1.0, dtype=tf.dtypes.float32)
        x_hat = epsilon * real + (1 - epsilon) * fake
        with tf.GradientTape() as t:
            t.watch(x_hat)
            d_hat = self.critic(x_hat)
        gradients = t.gradient(d_hat, x_hat)
        # The penalty is defined on the per-sample gradient norm, so reduce over
        # the feature axis only (reducing over every axis would collapse the
        # whole batch into one norm). The tiny constant keeps sqrt
        # differentiable when a gradient is exactly zero.
        ddx = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1) + 1e-12)
        d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
        return d_regularizer

    @tf.function
    def update_gradients(self, x, g_optimizer, c_optimizer):
        """Compute and apply the gradients for both the Generator and the Critic.

        Runs ``self.n_critic`` critic updates, then ``self.n_generator``
        generator updates, all on the same real batch ``x``.

        Args:
            x: real data event
            g_optimizer: generator optimizer
            c_optimizer: critic optimizer
        Returns:
            (critic loss, generator loss) of the last critic and generator update.
        """
        for _ in range(self.n_critic):
            with tf.GradientTape() as d_tape:
                critic_loss = self.c_lossfn(x)
            # Get the gradients of the critic
            d_gradient = d_tape.gradient(critic_loss, self.critic.trainable_variables)
            # Update the weights of the critic using the optimizer
            c_optimizer.apply_gradients(
                zip(d_gradient, self.critic.trainable_variables)
            )

        # Update the generator n_generator times.
        for _ in range(self.n_generator):
            with tf.GradientTape() as g_tape:
                gen_loss = self.g_lossfn(x)
            # Get the gradients of the generator
            gen_gradients = g_tape.gradient(gen_loss, self.generator.trainable_variables)
            # Update the weights of the generator
            g_optimizer.apply_gradients(
                zip(gen_gradients, self.generator.trainable_variables)
            )

        return critic_loss, gen_loss

    def c_lossfn(self, real):
        """Compute critic loss.

        Generates one fake row per real row, scores both with the critic and
        adds the weighted gradient penalty.

        Args:
            real: batch of real data.

        Returns:
            Scalar critic loss
            ``mean(critic(fake)) - mean(critic(real)) + gradient_penalty_weight * gp``.
        """
        # Standard-normal noise, one latent vector per real row; tf.shape keeps
        # this valid when the batch dimension is not statically known.
        noise = tf.random.normal([tf.shape(real)[0], self.noise_dim], dtype=tf.dtypes.float32)
        # run noise through generator
        fake = self.generator(noise)
        # Score real and generated rows with the critic.
        logits_real = self.critic(real)
        logits_fake = self.critic(fake)

        # gradient penalty
        gp = self.gradient_penalty(real, fake)
        # The critic minimises E[critic(fake)] - E[critic(real)] plus the penalty.
        c_loss = (tf.reduce_mean(logits_fake)
                  - tf.reduce_mean(logits_real)
                  + gp * self.gradient_penalty_weight)
        return c_loss

    def g_lossfn(self, real):
        """Compute generator loss.

        Args:
            real: batch of real data; only its batch size is used, to decide
                how many fake rows to generate.

        Returns:
            Scalar generator loss ``-mean(critic(fake))``.
        """
        # Standard-normal noise, one latent vector per real row.
        noise = tf.random.normal([tf.shape(real)[0], self.noise_dim], dtype=tf.dtypes.float32)

        fake = self.generator(noise)
        logits_fake = self.critic(fake)
        g_loss = -tf.reduce_mean(logits_fake)
        return g_loss

    def get_data_batch(self, train, batch_size, seed=0):
        """Return batch number ``seed`` from a shuffled walk over ``train``.

        The start offset is ``batch_size * seed``. Within each pass over the
        data (every ``len(train)`` rows of offset) the row order is a fixed,
        deterministic permutation seeded by the pass number. A batch that runs
        past the end of the permutation wraps around to its start.

        Args:
            train: real data, indexable by an integer array along axis 0
                (e.g. a numpy array).
            batch_size: number of rows to return.
            seed (int, optional): batch index. Defaults to 0.

        Returns:
            data batch of ``batch_size`` rows.

        Raises:
            ValueError: if ``train`` is empty.
        """
        n_rows = len(train)
        if n_rows == 0:
            raise ValueError("Cannot draw a batch from empty training data.")
        offset = batch_size * seed
        start_i = offset % n_rows
        shuffle_seed = offset // n_rows
        # A private RandomState gives the same permutation as the previous
        # np.random.seed(...) + np.random.choice(replace=False) pair, without
        # mutating numpy's global RNG state.
        train_ix = np.random.RandomState(shuffle_seed).permutation(n_rows)
        # Wrap indices so any batch size, even one larger than the data, is filled.
        batch_ix = train_ix[np.arange(start_i, start_i + batch_size) % n_rows]
        return train[batch_ix]

    def train_step(self, train_data, optimizers):
        """Perform a training step.

        Args:
            train_data: training data
            optimizers: (generator_optimizer, critic_optimizer) pair

        Returns:
            (critic_loss, generator_loss): Critic and generator loss.
        """
        cri_loss, ge_loss = self.update_gradients(train_data, *optimizers)
        return cri_loss, ge_loss

    def fit(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
        """Fit a synthesizer model to a given input dataset.

        Transforms ``data`` with the processor and builds the networks and
        optimizers. It then trains for ``train_arguments.epochs`` epochs. Every
        ``train_arguments.sample_interval`` epochs it saves the generator and
        critic weights under ``./cache``.

        Args:
            data: A pandas DataFrame or a Numpy array with the data to be synthesized.
            train_arguments: GAN training arguments (reads ``epochs``,
                ``sample_interval`` and ``cache_prefix``).
            num_cols (List[str]): List of columns of the data object to be handled as numerical.
            cat_cols (List[str]): List of columns of the data object to be handled as categorical.
        """
        super().fit(data, num_cols, cat_cols)

        processed_data = self.processor.transform(data)
        self.data_dim = processed_data.shape[1]
        optimizers = self.define_gan(self.processor.col_transform_info)

        # floor(n_rows / batch_size) + 1 training steps per epoch.
        iterations = int(abs(data.shape[0]/self.batch_size)+1)

        # Summary writer (nothing is written to it below). path.join keeps the
        # location portable; a literal '..\\wgan_gp_test' only works on Windows.
        train_summary_writer = tf.summary.create_file_writer(path.join('..', 'wgan_gp_test', 'summaries', 'train'))

        # Global step counter passed as get_data_batch's ``seed`` so consecutive
        # steps walk through the shuffled data instead of reusing the first batch.
        batch_step = 0
        with train_summary_writer.as_default():
            for epoch in trange(train_arguments.epochs):
                for _ in range(iterations):
                    batch_data = self.get_data_batch(processed_data, self.batch_size, seed=batch_step).astype(np.float32)
                    batch_step += 1
                    cri_loss, ge_loss = self.train_step(batch_data, optimizers)

                print(
                    "Epoch: {} | disc_loss: {} | gen_loss: {}".format(
                        epoch, cri_loss, ge_loss
                    ))

                if epoch % train_arguments.sample_interval == 0:
                    # Save model checkpoints.
                    os.makedirs('./cache', exist_ok=True)
                    model_checkpoint_base_name = './cache/' + train_arguments.cache_prefix + '_{}_model_weights_step_{}.h5'
                    self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
                    self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))

c_lossfn(real)

Compute critic loss.

Parameters:

Name Type Description Default
real

real data

required

Returns:

Type Description

critic loss

Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
def c_lossfn(self, real):
    """Compute critic loss.

    Generates one fake row per real row, scores both with the critic and
    adds the weighted gradient penalty.

    Args:
        real: batch of real data.

    Returns:
        Scalar critic loss
        ``mean(critic(fake)) - mean(critic(real)) + gradient_penalty_weight * gp``.
    """
    # Standard-normal noise, one latent vector per real row; tf.shape keeps
    # this valid when the batch dimension is not statically known.
    noise = tf.random.normal([tf.shape(real)[0], self.noise_dim], dtype=tf.dtypes.float32)
    # run noise through generator
    fake = self.generator(noise)
    # Score real and generated rows with the critic.
    logits_real = self.critic(real)
    logits_fake = self.critic(fake)

    # gradient penalty
    gp = self.gradient_penalty(real, fake)
    # The critic minimises E[critic(fake)] - E[critic(real)] plus the penalty.
    c_loss = (tf.reduce_mean(logits_fake)
              - tf.reduce_mean(logits_real)
              + gp * self.gradient_penalty_weight)
    return c_loss

define_gan(activation_info=None)

Define the trainable model components.

Parameters:

Name Type Description Default
activation_info Optional[NamedTuple]

Defaults to None.

None

Returns:

Type Description
(generator_optimizer, critic_optimizer)

Generator and critic optimizers.

Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
def define_gan(self, activation_info: Optional[NamedTuple] = None):
    """Build the generator, the critic and one Adam optimizer for each.

    Sets ``self.generator`` (input ``(noise_dim,)``, output ``data_dim``) and
    ``self.critic`` (input ``(data_dim,)``) as side effects.

    Args:
        activation_info (Optional[NamedTuple], optional): forwarded to the
            generator's ``build_model``. Defaults to None.

    Returns:
        (generator_optimizer, critic_optimizer): Generator and critic optimizers.
    """
    generator_builder = Generator(self.batch_size)
    self.generator = generator_builder.build_model(
        input_shape=(self.noise_dim,),
        dim=self.layers_dim,
        data_dim=self.data_dim,
        activation_info=activation_info,
        tau=self.tau,
    )

    critic_builder = Critic(self.batch_size)
    self.critic = critic_builder.build_model(
        input_shape=(self.data_dim,),
        dim=self.layers_dim,
    )

    # Both optimizers share the Adam betas and differ only in learning rate.
    adam_betas = {"beta_1": self.beta_1, "beta_2": self.beta_2}
    generator_optimizer = Adam(self.g_lr, **adam_betas)
    critic_optimizer = Adam(self.d_lr, **adam_betas)
    return generator_optimizer, critic_optimizer

fit(data, train_arguments, num_cols, cat_cols)

Fit a synthesizer model to a given input dataset.

Parameters:

Name Type Description Default
data

A pandas DataFrame or a Numpy array with the data to be synthesized.

required
train_arguments TrainParameters

GAN training arguments.

required
num_cols List[str]

List of columns of the data object to be handled as numerical.

required
cat_cols List[str]

List of columns of the data object to be handled as categorical.

required
Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
def fit(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
    """Fit a synthesizer model to a given input dataset.

    Transforms ``data`` with the processor and builds the networks and
    optimizers. It then trains for ``train_arguments.epochs`` epochs. Every
    ``train_arguments.sample_interval`` epochs it saves the generator and
    critic weights under ``./cache``.

    Args:
        data: A pandas DataFrame or a Numpy array with the data to be synthesized.
        train_arguments: GAN training arguments (reads ``epochs``,
            ``sample_interval`` and ``cache_prefix``).
        num_cols (List[str]): List of columns of the data object to be handled as numerical.
        cat_cols (List[str]): List of columns of the data object to be handled as categorical.
    """
    super().fit(data, num_cols, cat_cols)

    processed_data = self.processor.transform(data)
    self.data_dim = processed_data.shape[1]
    optimizers = self.define_gan(self.processor.col_transform_info)

    # floor(n_rows / batch_size) + 1 training steps per epoch.
    iterations = int(abs(data.shape[0]/self.batch_size)+1)

    # Summary writer (nothing is written to it below). path.join keeps the
    # location portable; a literal '..\\wgan_gp_test' only works on Windows.
    train_summary_writer = tf.summary.create_file_writer(path.join('..', 'wgan_gp_test', 'summaries', 'train'))

    # Global step counter passed as get_data_batch's ``seed`` so consecutive
    # steps walk through the shuffled data instead of reusing the first batch.
    batch_step = 0
    with train_summary_writer.as_default():
        for epoch in trange(train_arguments.epochs):
            for _ in range(iterations):
                batch_data = self.get_data_batch(processed_data, self.batch_size, seed=batch_step).astype(np.float32)
                batch_step += 1
                cri_loss, ge_loss = self.train_step(batch_data, optimizers)

            print(
                "Epoch: {} | disc_loss: {} | gen_loss: {}".format(
                    epoch, cri_loss, ge_loss
                ))

            if epoch % train_arguments.sample_interval == 0:
                # Save model checkpoints.
                os.makedirs('./cache', exist_ok=True)
                model_checkpoint_base_name = './cache/' + train_arguments.cache_prefix + '_{}_model_weights_step_{}.h5'
                self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
                self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))

g_lossfn(real)

Compute generator loss.

Parameters:

Name Type Description Default
real

A batch of real samples; only its batch size is used, to decide how many fake samples to generate.

required

Returns:

Type Description

Loss of the generator

Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
def g_lossfn(self, real):
    """Compute generator loss.

    Args:
        real: batch of real data; only its batch size is used, to decide how
            many fake rows to generate.

    Returns:
        Scalar generator loss ``-mean(critic(fake))``.
    """
    # Standard-normal noise, one latent vector per real row; tf.shape keeps
    # this valid when the batch dimension is not statically known.
    noise = tf.random.normal([tf.shape(real)[0], self.noise_dim], dtype=tf.dtypes.float32)

    fake = self.generator(noise)
    logits_fake = self.critic(fake)
    g_loss = -tf.reduce_mean(logits_fake)
    return g_loss

get_data_batch(train, batch_size, seed=0)

Get real data batches from the passed data object.

Parameters:

Name Type Description Default
train

real data.

required
batch_size

batch size.

required
seed int

Index of the batch to return from the shuffled data (batch start offset is batch_size * seed). Defaults to 0.

0

Returns:

Type Description

data batch.

Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
def get_data_batch(self, train, batch_size, seed=0):
    """Return batch number ``seed`` from a shuffled walk over ``train``.

    The start offset is ``batch_size * seed``. Within each pass over the data
    (every ``len(train)`` rows of offset) the row order is a fixed,
    deterministic permutation seeded by the pass number. A batch that runs
    past the end of the permutation wraps around to its start.

    Args:
        train: real data, indexable by an integer array along axis 0
            (e.g. a numpy array).
        batch_size: number of rows to return.
        seed (int, optional): batch index. Defaults to 0.

    Returns:
        data batch of ``batch_size`` rows.

    Raises:
        ValueError: if ``train`` is empty.
    """
    n_rows = len(train)
    if n_rows == 0:
        raise ValueError("Cannot draw a batch from empty training data.")
    offset = batch_size * seed
    start_i = offset % n_rows
    shuffle_seed = offset // n_rows
    # A private RandomState gives the same permutation as the previous
    # np.random.seed(...) + np.random.choice(replace=False) pair, without
    # mutating numpy's global RNG state.
    train_ix = np.random.RandomState(shuffle_seed).permutation(n_rows)
    # Wrap indices so any batch size, even one larger than the data, is filled.
    batch_ix = train_ix[np.arange(start_i, start_i + batch_size) % n_rows]
    return train[batch_ix]

gradient_penalty(real, fake)

Compute gradient penalty.

Parameters:

Name Type Description Default
real

real event.

required
fake

fake event.

required
Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
def gradient_penalty(self, real, fake):
    """Compute the WGAN-GP gradient penalty.

    Interpolates randomly between paired real and fake rows and penalises the
    critic when the L2 norm of its input gradient at those points differs
    from 1.

    Args:
        real: batch of real rows; assumed rank 2, ``(batch, data_dim)``.
        fake: batch of generated rows with the same shape as ``real``.

    Returns:
        Scalar tensor: batch mean of ``(||d critic(x_hat) / d x_hat||_2 - 1) ** 2``.
    """
    # One interpolation coefficient per row, broadcast across the features.
    epsilon = tf.random.uniform([tf.shape(real)[0], 1], minval=0.0, maxval=1.0, dtype=tf.dtypes.float32)
    x_hat = epsilon * real + (1 - epsilon) * fake
    with tf.GradientTape() as t:
        t.watch(x_hat)
        d_hat = self.critic(x_hat)
    gradients = t.gradient(d_hat, x_hat)
    # The penalty is defined on the per-sample gradient norm, so reduce over
    # the feature axis only (reducing over every axis would collapse the whole
    # batch into one norm). The tiny constant keeps sqrt differentiable when a
    # gradient is exactly zero.
    ddx = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1) + 1e-12)
    d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
    return d_regularizer

train_step(train_data, optimizers)

Perform a training step.

Parameters:

Name Type Description Default
train_data

training data

required
optimizers

generator and critic optimizers

required

Returns:

Type Description
(critic_loss, generator_loss)

Critic and generator loss.

Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
def train_step(self, train_data, optimizers):
    """Run one optimisation step on a batch by delegating to ``update_gradients``.

    Args:
        train_data: batch of training data.
        optimizers: (generator_optimizer, critic_optimizer), splatted
            positionally into ``update_gradients``.

    Returns:
        (critic_loss, generator_loss): Critic and generator loss.
    """
    losses = self.update_gradients(train_data, *optimizers)
    critic_loss, generator_loss = losses
    return critic_loss, generator_loss

update_gradients(x, g_optimizer, c_optimizer)

Compute and apply the gradients for both the Generator and the Critic.

Parameters:

Name Type Description Default
x

real data event

required
g_optimizer

generator optimizer

required
c_optimizer

critic optimizer

required
Source code in `ydata_synthetic/synthesizers/regular/wgangp/model.py`
@tf.function
def update_gradients(self, x, g_optimizer, c_optimizer):
    """Train the critic ``n_critic`` times, then the generator ``n_generator`` times.

    Every update uses the same real batch ``x``.

    Args:
        x: real data event
        g_optimizer: optimizer applied to the generator's trainable variables
        c_optimizer: optimizer applied to the critic's trainable variables
    Returns:
        (critic loss, generator loss) from the last critic and generator update.
    """
    # Critic phase.
    for _critic_step in range(self.n_critic):
        with tf.GradientTape() as critic_tape:
            critic_loss = self.c_lossfn(x)
        critic_vars = self.critic.trainable_variables
        critic_grads = critic_tape.gradient(critic_loss, critic_vars)
        c_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # Generator phase.
    for _generator_step in range(self.n_generator):
        with tf.GradientTape() as generator_tape:
            gen_loss = self.g_lossfn(x)
        generator_vars = self.generator.trainable_variables
        generator_grads = generator_tape.gradient(gen_loss, generator_vars)
        g_optimizer.apply_gradients(zip(generator_grads, generator_vars))

    return critic_loss, gen_loss