Tabular Data Generation using TGAN
I started this project after coming across Omdena's challenge, Leveraging NLP in Medical Prescription Administration and Information.
The project looked both interesting and useful to the community, so I began working on it. The biggest challenge in building any ML model is getting good-quality data. My first stop for pharma data was Google Dataset Search, but I could not find enough data to build a recommendation system.
So I decided to use TGAN to generate synthetic data, which helps preserve privacy while still supporting a meaningful healthcare solution.
Let’s begin!
Synthetic data generation has become an essential tool in scenarios where access to real-world datasets is restricted due to privacy concerns or limited availability. For fields like healthcare, where patient confidentiality is critical, synthetic data provides a viable alternative for training machine learning models, conducting research, and performing exploratory data analysis.
This blog will provide a detailed guide on creating synthetic medical data using Generative Adversarial Networks (GANs), specifically leveraging the TGAN (Tabular GAN) framework.
What Is Synthetic Medical Data?
Synthetic medical data is artificially generated data that mimics the statistical properties and patterns of real medical datasets without containing any real patient information. Such data can include:
- Patient demographics (e.g., age, gender).
- Clinical measurements (e.g., blood pressure, glucose levels).
- Diagnostic codes (e.g., ICD-10 codes).
- Treatment histories.
Overview of TGAN
TGAN is a GAN-based model designed to generate tabular data that contains both continuous and categorical variables. Unlike image-generation GANs, TGAN addresses challenges specific to tabular data, such as multimodal distributions and variable correlations.
Step-by-Step Guide to Creating Synthetic Medical Data
Step 1: Define Data Requirements
- Determine Data Structure:
  - Identify the types of variables (e.g., numerical, categorical).
  - Define the relationships between variables (e.g., correlations).
- Select a Proxy Dataset:
  - Use existing research datasets, such as:
    - MIMIC-III (Electronic Health Records; access requires credentialing through PhysioNet).
    - UCI Heart Disease dataset.
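As a minimal sketch of the first bullet, the snippet below splits the columns of a pandas DataFrame into numerical and categorical groups and computes the correlations between the numerical ones. The toy `df` and its column names are made up for illustration and do not come from any of the datasets above.

```python
import pandas as pd

# Hypothetical toy records; column names are illustrative only
df = pd.DataFrame({
    "age": [34, 58, 47, 71],
    "bmi": [22.1, 30.4, 27.8, 25.0],
    "gender": ["F", "M", "F", "M"],
    "diagnosis": ["I10", "E11", "I10", "J45"],
})

# Numerical columns are treated as continuous, everything else as categorical
numerical_cols = df.select_dtypes(include="number").columns.tolist()
categorical_cols = df.select_dtypes(exclude="number").columns.tolist()

# Pairwise Pearson correlations between the numerical columns
correlations = df[numerical_cols].corr()

print(numerical_cols)
print(categorical_cols)
print(correlations)
```

Integer-coded categories (for example, 0/1 flags) would be picked up as numerical here, so the split usually needs a manual review.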
Step 2: Preprocess the Data
Before training the TGAN model, preprocess the data to ensure compatibility with neural network training.
1. Numerical Variables:
   - Normalize using Gaussian Mixture Models (GMM) to handle multimodal distributions.
   - Clip the normalized values to [-0.99, 0.99] so they fall within the (-1, 1) range of the generator's tanh output.
from sklearn.mixture import GaussianMixture
import numpy as np

def preprocess_numerical(data, n_components=5):
    # data: 1-D NumPy array holding one numerical column
    column = data.reshape(-1, 1)
    gmm = GaussianMixture(n_components=n_components)
    gmm.fit(column)
    # Probability of each value belonging to each Gaussian mode
    cluster_probs = gmm.predict_proba(column)
    # Normalize each value by the mean and standard deviation of its most likely mode
    mode = np.argmax(cluster_probs, axis=1)
    means = gmm.means_.flatten()
    stds = np.sqrt(gmm.covariances_.flatten())
    normalized_data = (data - means[mode]) / (2 * stds[mode])
    return np.clip(normalized_data, -0.99, 0.99), cluster_probs
2. Categorical Variables:
   - Convert to one-hot encoding.
   - Add small uniform noise and renormalize each row so it still sums to 1; this smooths the hard 0/1 values and helps model training.
import numpy as np
from sklearn.preprocessing import OneHotEncoder

def preprocess_categorical(data, noise=0.2):
    # data: 1-D NumPy array holding one categorical column
    encoder = OneHotEncoder()
    one_hot = encoder.fit_transform(data.reshape(-1, 1)).toarray()
    # Add uniform noise in [0, noise) to every entry, then renormalize each row
    noise_matrix = np.random.uniform(0, noise, one_hot.shape)
    noisy_data = one_hot + noise_matrix
    return noisy_data / noisy_data.sum(axis=1, keepdims=True)
3. Unified Representation:
   - Combine processed numerical and categorical variables into a single representation.
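One possible way to build the unified representation is sketched below: for every numerical column, the normalized value is concatenated with its cluster probabilities, and the noisy one-hot block of every categorical column is appended. The helper name `build_representation` and the toy arrays are assumptions for illustration.

```python
import numpy as np

def build_representation(numerical_parts, categorical_parts):
    """Concatenate per-column blocks into one [n_rows, total_width] matrix.

    numerical_parts: list of (values, cluster_probs) pairs, one per column,
        with values of shape [n_rows] and cluster_probs of shape [n_rows, k].
    categorical_parts: list of noisy one-hot arrays of shape [n_rows, n_cats].
    """
    blocks = []
    for values, cluster_probs in numerical_parts:
        blocks.append(values.reshape(-1, 1))  # normalized scalar value
        blocks.append(cluster_probs)          # which mode the value belongs to
    blocks.extend(categorical_parts)
    return np.concatenate(blocks, axis=1)

# Toy example: 3 rows, one numerical column with 2 modes, one categorical with 3 levels
values = np.array([0.10, -0.45, 0.80])
cluster_probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
one_hot = np.array([[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]])

representation = build_representation([(values, cluster_probs)], [one_hot])
print(representation.shape)  # (3, 6)
```

Keeping track of which slice of the matrix belongs to which column makes it possible to reverse the transformation after generation.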
Step 3: Train the TGAN Model
- Model Components:
  - Generator: Uses a Long Short-Term Memory (LSTM) network with attention to generate data column by column.
  - Discriminator: A Multi-Layer Perceptron (MLP) that distinguishes between real and synthetic data.
- Training Objectives:
  - Generator Loss: Fool the discriminator while also minimizing a KL divergence term on the discrete variables, which pushes their generated distributions toward the real ones.
  - Discriminator Loss: Use cross-entropy loss to classify real vs. synthetic data.
- Training Process:
  - Use the Adam optimizer to train the model iteratively.
  - Monitor the generator and discriminator losses to check for convergence and stability.
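To make the two objectives concrete, here is a small NumPy sketch of the discriminator's cross-entropy loss and of a KL divergence term over one discrete variable. The function names and toy numbers are illustrative assumptions, and the exact way TGAN weights and combines these terms is not reproduced here.

```python
import numpy as np

EPS = 1e-8  # avoids log(0)

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy: real rows should score 1, synthetic rows 0."""
    return float(-np.mean(np.log(d_real + EPS)) - np.mean(np.log(1.0 - d_fake + EPS)))

def kl_divergence(p_real, p_fake):
    """KL(p_real || p_fake) between two category-frequency vectors."""
    p_real = p_real / p_real.sum()
    p_fake = p_fake / p_fake.sum()
    return float(np.sum(p_real * np.log((p_real + EPS) / (p_fake + EPS))))

# Discriminator scores (probability of "real") for a toy batch
d_real = np.array([0.9, 0.8, 0.95])
d_fake = np.array([0.2, 0.1, 0.3])
print(discriminator_loss(d_real, d_fake))

# Average one-hot block of a categorical column in a real and a generated batch
real_freq = np.array([0.5, 0.3, 0.2])
fake_freq = np.array([0.4, 0.4, 0.2])
print(kl_divergence(real_freq, fake_freq))
```

The KL term is zero when the generated category frequencies match the real ones and grows as they drift apart.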
import tensorflow as tf

class TGANGenerator(tf.keras.Model):
    """Simplified sketch of a TGAN-style generator, not the full TGAN architecture."""

    def __init__(self, latent_dim, num_features):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_features = num_features
        # Repeat the noise vector so the LSTM takes one step per output column
        self.repeat = tf.keras.layers.RepeatVector(num_features)
        self.lstm = tf.keras.layers.LSTM(128, return_sequences=True)
        self.attention = tf.keras.layers.Attention()
        self.feature_dense = [
            tf.keras.layers.Dense(1, activation='tanh') for _ in range(num_features)
        ]

    def call(self, inputs):
        # inputs: noise of shape [batch, latent_dim]
        x = self.repeat(inputs)  # [batch, num_features, latent_dim]
        x = self.lstm(x)         # [batch, num_features, 128]
        # Self-attention across the column steps (a simplification of
        # attending only to previously generated columns)
        context = self.attention([x, x])
        # One dense head per column, applied to that column's step
        outputs = [dense(context[:, i, :]) for i, dense in enumerate(self.feature_dense)]
        return tf.concat(outputs, axis=-1)  # [batch, num_features]

class TGANDiscriminator(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(128),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])

    def call(self, inputs):
        # Returns the probability that each row is real
        return self.mlp(inputs)

# Instantiate models
latent_dim = 100
num_features = 10  # Adjust based on the dataset
generator = TGANGenerator(latent_dim, num_features)
discriminator = TGANDiscriminator()
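The classes above only define the networks. As a rough sketch of how an adversarial training step with Adam could look, the snippet below uses two small stand-in MLPs instead of the LSTM generator so that it runs on its own. The names `train_step`, `g_optimizer`, and `d_optimizer`, the layer sizes, and the random placeholder batch are all assumptions, and the KL divergence term is left out for brevity.

```python
import tensorflow as tf

latent_dim, num_features, batch_size = 16, 4, 32

# Stand-in networks so the sketch is self-contained (much simpler than TGAN)
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(num_features, activation="tanh"),
])
discriminator = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

bce = tf.keras.losses.BinaryCrossentropy()
g_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
d_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

def train_step(real_batch):
    noise = tf.random.normal([tf.shape(real_batch)[0], latent_dim])
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_batch = generator(noise, training=True)
        d_real = discriminator(real_batch, training=True)
        d_fake = discriminator(fake_batch, training=True)
        # Discriminator: label real rows 1 and synthetic rows 0
        d_loss = bce(tf.ones_like(d_real), d_real) + bce(tf.zeros_like(d_fake), d_fake)
        # Generator: try to make the discriminator output 1 on synthetic rows
        g_loss = bce(tf.ones_like(d_fake), d_fake)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return float(g_loss), float(d_loss)

# Placeholder "real" data already scaled to (-1, 1)
real_batch = tf.random.uniform([batch_size, num_features], -1.0, 1.0)
for step in range(3):
    g_loss, d_loss = train_step(real_batch)
    print(step, g_loss, d_loss)
```

In a real run, `real_batch` would come from the preprocessed table, and the loop would iterate over many batches and epochs while logging both losses.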
Step 4: Generate Synthetic Medical Data
- Once the TGAN model is trained, generate synthetic data by sampling from the generator.
- Adjust the number of samples based on project requirements.
- Map the generator's output back from the normalized representation to the original scales and categories, so the result has the same schema (e.g., column names and data types) as the original dataset.
def generate_synthetic_data(generator, num_samples, latent_dim):
    # Sample Gaussian noise and pass it through the trained generator
    noise = tf.random.normal([num_samples, latent_dim])
    return generator(noise).numpy()

synthetic_data = generate_synthetic_data(generator, num_samples=1000, latent_dim=100)
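The generator produces values in the normalized space, so they need to be mapped back before they can be used like the original table. The sketch below inverts the mode-specific normalization from Step 2 and turns soft one-hot blocks back into labels. The function names, the toy means and standard deviations, and the category list are hypothetical.

```python
import numpy as np

def inverse_numerical(normalized, cluster_probs, means, stds):
    """Undo v = (x - mean_k) / (2 * std_k) using the most likely mode k."""
    k = np.argmax(cluster_probs, axis=1)
    return normalized * 2.0 * stds[k] + means[k]

def inverse_categorical(soft_one_hot, categories):
    """Pick the most probable category for each row."""
    return np.asarray(categories)[np.argmax(soft_one_hot, axis=1)]

# Toy generator output for one numerical and one categorical column
normalized = np.array([0.5, -0.25])
cluster_probs = np.array([[0.9, 0.1], [0.2, 0.8]])
means = np.array([40.0, 70.0])  # assumed GMM component means (e.g., age)
stds = np.array([5.0, 4.0])     # assumed GMM component standard deviations

ages = inverse_numerical(normalized, cluster_probs, means, stds)
print(ages)  # [45. 68.]

labels = inverse_categorical(np.array([[0.7, 0.3], [0.1, 0.9]]), ["F", "M"])
print(labels)  # ['F' 'M']
```

This only works if the fitted GMM parameters and the category order from preprocessing are saved alongside the model.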
Step 5: Validate the Synthetic Data
Statistical Validation:
- Compare real and synthetic data's summary statistics (mean, median, standard deviation).
def validate_statistics(real_data, synthetic_data):
    # Both arguments are pandas DataFrames with the same columns
    real_stats = real_data.describe()
    synthetic_stats = synthetic_data.describe()
    return real_stats, synthetic_stats
- Evaluate pairwise correlations using metrics like normalized mutual information (NMI).
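A minimal sketch of the pairwise NMI check is shown below, using scikit-learn's `normalized_mutual_info_score` after binning numerical columns. The helper `pairwise_nmi`, the choice of 5 bins, and the randomly generated toy tables are assumptions for illustration; in practice the two tables would be the real and the generated data.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import normalized_mutual_info_score

def pairwise_nmi(df, bins=5):
    """NMI between every pair of columns; numerical columns are binned first."""
    discrete = pd.DataFrame(index=df.index)
    for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]):
            discrete[col] = pd.cut(df[col], bins=bins, labels=False)
        else:
            discrete[col] = pd.factorize(df[col])[0]
    cols = list(df.columns)
    nmi = pd.DataFrame(np.eye(len(cols)), index=cols, columns=cols)
    for i, a in enumerate(cols):
        for b in cols[i + 1:]:
            score = normalized_mutual_info_score(discrete[a], discrete[b])
            nmi.loc[a, b] = nmi.loc[b, a] = score
    return nmi

def toy_table(seed, n=200):
    # Stand-in data: cholesterol depends on age, gender is independent
    rng = np.random.default_rng(seed)
    age = rng.normal(55, 12, n)
    return pd.DataFrame({
        "age": age,
        "cholesterol": 150 + 1.2 * age + rng.normal(0, 15, n),
        "gender": rng.choice(["F", "M"], n),
    })

real_nmi = pairwise_nmi(toy_table(0))
synthetic_nmi = pairwise_nmi(toy_table(1))
# Small absolute differences suggest the pairwise dependencies are preserved
print((real_nmi - synthetic_nmi).abs())
```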
Machine Learning Performance:
- Train models (e.g., Decision Trees, SVMs) on synthetic data and test on real data.
- Ensure performance metrics (e.g., accuracy, F1-score) are comparable.
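The train-on-synthetic, test-on-real check can be sketched as follows. The data here is randomly generated as a stand-in for both the real and the synthetic tables, and the helper names `toy_dataset` and `evaluate` are made up for this example.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from sklearn.tree import DecisionTreeClassifier

def toy_dataset(seed, n=300):
    # Stand-in data: the label depends on the first two features
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, 4))
    y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)
    return X, y

X_real, y_real = toy_dataset(0)
X_synth, y_synth = toy_dataset(1)  # pretend this came from the generator

# Hold out part of the real data as the common test set
X_real_train, y_real_train = X_real[:200], y_real[:200]
X_real_test, y_real_test = X_real[200:], y_real[200:]

def evaluate(X_train, y_train):
    """Train a decision tree and score it on the held-out real rows."""
    model = DecisionTreeClassifier(max_depth=4, random_state=0)
    model.fit(X_train, y_train)
    pred = model.predict(X_real_test)
    return accuracy_score(y_real_test, pred), f1_score(y_real_test, pred)

real_acc, real_f1 = evaluate(X_real_train, y_real_train)
synth_acc, synth_f1 = evaluate(X_synth, y_synth)
print("train on real:     ", real_acc, real_f1)
print("train on synthetic:", synth_acc, synth_f1)
```

A large gap between the two rows would indicate that the synthetic data misses patterns that matter for the downstream task.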
Privacy Assessment:
- Measure the nearest-neighbour distance from each synthetic record to the real data; distances at or near zero suggest that real records have leaked into the synthetic set.
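A simple version of this check, written with plain NumPy, is sketched below. The function name `nearest_real_distance`, the random toy arrays, and the 1e-6 threshold are assumptions; one record is copied on purpose to show what a leak looks like.

```python
import numpy as np

def nearest_real_distance(synthetic, real):
    """Euclidean distance from each synthetic row to its closest real row."""
    diffs = synthetic[:, None, :] - real[None, :, :]  # [n_synth, n_real, n_features]
    distances = np.sqrt((diffs ** 2).sum(axis=2))
    return distances.min(axis=1)

rng = np.random.default_rng(0)
real = rng.normal(size=(100, 5))
synthetic = rng.normal(size=(50, 5))
synthetic[0] = real[3]  # simulate one leaked (copied) record

d = nearest_real_distance(synthetic, real)
leaked = np.where(d < 1e-6)[0]
print("smallest distance:", d.min())
print("synthetic rows that match a real record:", leaked)
```

Features should be on comparable scales before computing distances, and for large tables a nearest-neighbour index is more memory-friendly than the full distance matrix used here.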
Tools for Implementation
1. Frameworks:
   - TGAN: Open-source implementation of Tabular GAN.
   - SDV (Synthetic Data Vault): A library for generating and evaluating synthetic data.
2. Programming Languages:
   - Python with libraries such as TensorFlow or PyTorch.
3. Datasets for Training:
   - UCI Repository (Heart Disease, Diabetes).
   - Kaggle (Healthcare-specific datasets).
   - Google Dataset Search.
   - MIMIC-III.
Benefits of TGAN for Medical Data
1. Captures Complex Relationships:
   - Preserves correlations between variables.
2. Handles Mixed Data Types:
   - Supports both continuous (e.g., lab results) and categorical (e.g., diagnosis) variables.
3. Scalable:
   - Suitable for large datasets, enabling broader applications.
4. Customizable:
   - Allows users to tailor data generation to specific needs (e.g., patient demographics).
Use Case Example
Imagine a project requiring synthetic EHR data:
1. Variables:
   - Continuous: Age, BMI, cholesterol level.
   - Categorical: Gender, diagnosis codes (e.g., ICD-10).
2. Steps:
   - Preprocess real EHR data.
   - Train TGAN to learn patterns and relationships.
   - Generate synthetic records for testing and research.
3. Outcome:
   - Synthetic data closely resembles real data in structure and statistical properties.
Conclusion
Using TGAN, researchers and developers can generate high-quality synthetic medical data that retains the statistical properties of real datasets while safeguarding patient privacy. Following the steps outlined above, you can create scalable and privacy-compliant synthetic datasets for a wide range of healthcare applications.
If you have any questions or need guidance on implementation, feel free to leave a comment below!