class CTGAN(BaseGANModel):
    """
    Conditional Tabular GAN model.
    Based on the paper "Modeling Tabular Data using Conditional GAN"
    (https://arxiv.org/abs/1907.00503).
    Args:
        model_parameters: Parameters used to create the CTGAN model.
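    Example:
        A minimal usage sketch (assuming `ModelParameters`/`TrainParameters`
        expose these fields and `df` is a pandas DataFrame whose columns are
        split into `num_cols` and `cat_cols`):
            >>> synth = CTGAN(ModelParameters(batch_size=500))
            >>> synth.fit(df, TrainParameters(epochs=300), num_cols, cat_cols)
            >>> synthetic_df = synth.sample(1000)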
    """
    __MODEL__ = 'CTGAN'
    def __init__(self, model_parameters: ModelParameters):
        super().__init__(model_parameters)
        if self.batch_size % 2 != 0 or self.batch_size % self.pac != 0:
            raise ValueError("The batch size needs to be an even value divisible by the PAC.")
        self._model_parameters = model_parameters
        self._real_data_sampler = None
        self._conditional_sampler = None
        self._generator_model = None
        self._critic_model = None
    @staticmethod
    def _create_generator_model(input_dim, generator_dims, data_dim, metadata, tau):
        """
        Creates the generator model.
        Args:
            input_dim: Input dimensionality.
            generator_dims: Dimensions of each hidden layer.
            data_dim: Output dimensionality.
            metadata: Dataset columns metadata.
            tau: Non-negative temperature of the Gumbel-Softmax distribution.
        """
        model_input = Input(shape=(input_dim,))
        x = model_input
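        # Each hidden block is residual-style: the block input is concatenated
        # to its output, so the feature width grows by layer_dim per layer.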
        for layer_dim in generator_dims:
            layer_input = x
            x = Dense(layer_dim,
                      kernel_initializer="random_uniform",
                      bias_initializer="random_uniform")(x)
            x = BatchNormalization(epsilon=1e-5, momentum=0.9)(x)
            x = ReLU()(x)
            x = Concatenate(axis=1)([x, layer_input])
        def _gumbel_softmax(logits, tau=1.0):
            """Applies the Gumbel-Softmax function to the given logits."""
            gumbel_dist = tfp.distributions.Gumbel(loc=0, scale=1)
            gumbels = gumbel_dist.sample(tf.shape(logits))
            gumbels = (logits + gumbels) / tau
            return tf.nn.softmax(gumbels, axis=-1)
        def _generator_activation(data):
            """Custom activation function for the generator model."""
            data_transformed = []
            for col_md in metadata:
                if col_md.discrete:
                    logits = data[:, col_md.start_idx:col_md.end_idx]
                    data_transformed.append(_gumbel_softmax(logits, tau=tau))
                else:
                    data_transformed.append(tf.math.tanh(data[:, col_md.start_idx:col_md.start_idx+1]))
                    logits = data[:, col_md.start_idx+1:col_md.end_idx]
                    data_transformed.append(_gumbel_softmax(logits, tau=tau))
            return data, tf.concat(data_transformed, axis=1)
        x = Dense(data_dim, kernel_initializer="random_uniform",
                  bias_initializer="random_uniform",
                  activation=_generator_activation)(x)
        return Model(inputs=model_input, outputs=x)
    @staticmethod
    def _create_critic_model(input_dim, critic_dims, pac):
        """
        Creates the critic model.
        Args:
            input_dim: Input dimensionality.
            critic_dims: Dimensions of each hidden layer.
            pac: PAC size.
        """
        model_input = Input(shape=(input_dim,))
        x = tf.reshape(model_input, [-1, input_dim * pac])
        for dim in critic_dims:
            x = Dense(dim,
                      kernel_initializer="random_uniform",
                      bias_initializer="random_uniform")(x)
            x = LeakyReLU(0.2)(x)
            x = Dropout(0.5)(x)
        x = Dense(1, kernel_initializer="random_uniform",
                  bias_initializer="random_uniform")(x)
        return Model(inputs=model_input, outputs=x)
    def fit(self, data: DataFrame, train_arguments: TrainParameters, num_cols: list[str], cat_cols: list[str]):
        """
        Fits the CTGAN model.
        Args:
            data: A pandas DataFrame with the data to be synthesized.
            train_arguments: CTGAN training arguments.
            num_cols: List of columns to be handled as numerical.
            cat_cols: List of columns to be handled as categorical.
        """
        super().fit(data=data, num_cols=num_cols, cat_cols=cat_cols, train_arguments=train_arguments)
        self._generator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        self._critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)
        train_data = self.processor.transform(data)
        metadata = self.processor.metadata
        data_dim = self.processor.output_dimensions
        self._real_data_sampler = RealDataSampler(train_data, metadata)
        self._conditional_sampler = ConditionalSampler(train_data, metadata, train_arguments.log_frequency)
        gen_input_dim = self.latent_dim + self._conditional_sampler.output_dimensions
        self._generator_model = self._create_generator_model(
            gen_input_dim, self.generator_dims, data_dim, metadata, self.tau)
        crt_input_dim = data_dim + self._conditional_sampler.output_dimensions
        self._critic_model = self._create_critic_model(crt_input_dim, self.critic_dims, self.pac)
        self._generator_model.build((self.batch_size, gen_input_dim))
        self._critic_model.build((self.batch_size, crt_input_dim))
        steps_per_epoch = max(len(train_data) // self.batch_size, 1)
        for epoch in range(train_arguments.epochs):
            for _ in range(steps_per_epoch):
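                # Critic update: sample noise plus a conditional vector, and
                # draw real rows matching the sampled (column, option) pairs.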
                fake_z = tf.random.normal([self.batch_size, self.latent_dim])
                cond_vector = self._conditional_sampler.sample(self.batch_size)
                if cond_vector is None:
                    real = self._real_data_sampler.sample(self.batch_size)
                else:
                    cond, _, col_idx, opt_idx = cond_vector
                    cond = tf.convert_to_tensor(cond)
                    fake_z = tf.concat([fake_z, cond], 1)
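                    # Shuffle the conditions before sampling real rows so the
                    # real batch follows the same condition distribution as the
                    # fake batch without being pairwise aligned with it.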
                    perm = np.arange(self.batch_size)
                    np.random.shuffle(perm)
                    real = self._real_data_sampler.sample_col(col_idx[perm], opt_idx[perm])
                    cond_perm = tf.gather(cond, perm)
                fake, fake_act = self._generator_model(fake_z, training=True)
                real = tf.convert_to_tensor(real.astype('float32'))
                real_cat = real if cond_vector is None else tf.concat([real, cond_perm], 1)
                fake_cat = fake if cond_vector is None else tf.concat([fake_act, cond], 1)
                critic_loss = self._train_critic_step(real_cat, fake_cat)
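                # Generator update: fresh noise and conditional vector; the
                # conditional loss pushes generated one-hot columns to honor
                # the sampled condition.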
                fake_z = tf.random.normal([self.batch_size, self.latent_dim])
                cond_vector = self._conditional_sampler.sample(self.batch_size)
                if cond_vector is None:
                    generator_loss = self._train_generator_step(fake_z)
                else:
                    cond, mask, _, _ = cond_vector
                    cond = tf.convert_to_tensor(cond)
                    mask = tf.convert_to_tensor(mask)
                    fake_z = tf.concat([fake_z, cond], axis=1)
                    generator_loss = self._train_generator_step(fake_z, cond, mask, metadata)
            print(f"Epoch: {epoch} | critic_loss: {critic_loss} | generator_loss: {generator_loss}")
    def _train_critic_step(self, real, fake):
        """
        Single training iteration of the critic model.
        Args:
            real: Real data.
            fake: Fake data.
        """
        with tf.GradientTape() as tape:
            y_real = self._critic_model(real, training=True)
            y_fake = self._critic_model(fake, training=True)
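            # WGAN-GP gradient penalty computed on interpolations between the
            # real and fake batches.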
            gp = gradient_penalty(
                partial(self._critic_model, training=True), real, fake, ModeGP.CTGAN, self.pac)
            wasserstein_loss = -(tf.reduce_mean(y_real) - tf.reduce_mean(y_fake))
            critic_loss = wasserstein_loss + gp * self.gp_lambda
        gradient = tape.gradient(critic_loss, self._critic_model.trainable_variables)
        self._apply_critic_gradients(gradient, self._critic_model.trainable_variables)
        return critic_loss
    @tf.function
    def _apply_critic_gradients(self, gradient, trainable_variables):
        """
        Updates gradients of the critic model.
        This logic is isolated in order to be optimized as a TF function.
        Args:
            gradient: Gradient.
            trainable_variables: Variables to be updated.
        """
        self._critic_optimizer.apply_gradients(zip(gradient, trainable_variables))
    def _train_generator_step(self, fake_z, cond_vector=None, mask=None, metadata=None):
        """
        Single training iteration of the generator model.
        Args:
            fake_z: Latent noise, concatenated with the conditional vector when one is used.
            cond_vector: Conditional vector.
            mask: Mask vector.
            metadata: Dataset columns metadata.
        """
        with tf.GradientTape() as tape:
            fake, fake_act = self._generator_model(fake_z, training=True)
            if cond_vector is not None:
                y_fake = self._critic_model(
                    tf.concat([fake_act, cond_vector], 1), training=True)
                cond_loss = ConditionalLoss.compute(fake, cond_vector, mask, metadata)
                generator_loss = -tf.reduce_mean(y_fake) + cond_loss
            else:
                y_fake = self._critic_model(fake_act, training=True)
                generator_loss = -tf.reduce_mean(y_fake)
        gradient = tape.gradient(generator_loss, self._generator_model.trainable_variables)
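        # Manual L2 regularization: add the scaled weights to their gradients,
        # which is equivalent to weight decay on the generator parameters.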
        gradient = [g + self.l2_scale * w
                    for g, w in zip(gradient, self._generator_model.trainable_variables)]
        self._apply_generator_gradients(gradient, self._generator_model.trainable_variables)
        return generator_loss
    @tf.function
    def _apply_generator_gradients(self, gradient, trainable_variables):
        """
        Updates gradients of the generator model.
        This logic is isolated in order to be optimized as a TF function.
        Args:
            gradient: Gradient.
            trainable_variables: Variables to be updated.
        """
        self._generator_optimizer.apply_gradients(zip(gradient, trainable_variables))
    def sample(self, n_samples: int):
        """
        Samples new data from the CTGAN.
        Args:
            n_samples: Number of samples to be generated.
        """
        if n_samples <= 0:
            raise ValueError("Invalid number of samples.")
        steps = (n_samples + self.batch_size - 1) // self.batch_size
        data = []
        for _ in range(steps):
            fake_z = tf.random.normal([self.batch_size, self.latent_dim])
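            # At sampling time the conditional vector is drawn from the
            # category bits observed in the training data.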
            cond_vec = self._conditional_sampler.sample(self.batch_size, from_active_bits=True)
            if cond_vec is not None:
                cond = tf.constant(cond_vec)
                fake_z = tf.concat([fake_z, cond], 1)
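            # Index 1 selects the activated output; index 0 holds raw outputs.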
            fake = self._generator_model(fake_z)[1]
            data.append(fake.numpy())
        data = np.concatenate(data, 0)
        data = data[:n_samples]
        return self.processor.inverse_transform(data)
    def save(self, path):
        """
        Save the CTGAN model in a pickle file.
        Only the required components to sample new data are saved.
        Args:
            path: Path of the pickle file.
        """
        dump({
            "model_parameters": self._model_parameters,
            "data_dim": self.processor.output_dimensions,
            "gen_input_dim": self.latent_dim + self._conditional_sampler.output_dimensions,
            "generator_dims": self.generator_dims,
            "tau": self.tau,
            "metadata": self.processor.metadata,
            "batch_size": self.batch_size,
            "latent_dim": self.latent_dim,
            "conditional_sampler": self._conditional_sampler.__dict__,
            "generator_model_weights": self._generator_model.get_weights(),
            "processor": self.processor.__dict__
        }, path)
    @staticmethod
    def load(class_dict):
        """
        Load the CTGAN model from a pickle file.
        Only the required components to sample new data are loaded.
        Args:
            class_dict: Class dict loaded from the pickle file.
        """
        new_instance = CTGAN(class_dict["model_parameters"])
        setattr(new_instance, "generator_dims", class_dict["generator_dims"])
        setattr(new_instance, "tau", class_dict["tau"])
        setattr(new_instance, "batch_size", class_dict["batch_size"])
        setattr(new_instance, "latent_dim", class_dict["latent_dim"])
        new_instance._conditional_sampler = ConditionalSampler()
        new_instance._conditional_sampler.__dict__ = class_dict["conditional_sampler"]
        new_instance.processor = CTGANDataProcessor()
        new_instance.processor.__dict__ = class_dict["processor"]
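        # Rebuild the generator architecture and restore the trained weights.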
        new_instance._generator_model = new_instance._create_generator_model(
            class_dict["gen_input_dim"], class_dict["generator_dims"], 
            class_dict["data_dim"], class_dict["metadata"], class_dict["tau"])
        new_instance._generator_model.build((class_dict["batch_size"], class_dict["gen_input_dim"]))
        new_instance._generator_model.set_weights(class_dict["generator_model_weights"])
        return new_instance