SDV data synthesis (Python)


Synthesizing Data with Generative Models for Better MLOps

Generative models are all the rage, and flashy examples have dominated headlines recently -- DALL-E, ChatGPT, diffusion models. But does your business problem require concocting weird art or off-kilter poetry? Unlikely, unfortunately. Yet this new class of approaches, which generates more data from data, has valuable and more prosaic applications.

Given real business data, GANs (Generative Adversarial Networks) and VAEs (variational autoencoders) can produce synthetic data that resembles real data. Is fake data useful? It can be, in cases where source data is sensitive and not readily shareable, yet something like the real data is needed to develop or test pipelines built on it. Perhaps a third-party data science team will develop a new modeling pipeline, but sharing sensitive data with them is not possible. Develop on synthetic data!

It's even possible that a bit of synthetic data alongside real data improves modeling outcomes.

This example explores use of the Python library SDV to generate synthetic data resembling a dataset, and then uses Databricks AutoML to assess the quality of models built with that synthetic data. SDV uses deep learning, a variational autoencoder in particular, via its TVAE model.

With this example, you too can exploit generative AI!

Setup

This notebook was run on DBR 12.2 ML, but should work on other recent versions. Use a single-GPU instance if desired (in which case, use 12.2 ML GPU); if not, set use_gpu=False below.

Install SDV along with supporting libraries for visualization. (pandas-profiling needs a small update to pick up a bug fix.)

%pip install "sdv==0.18.0" kaleido "pandas-profiling>=3.6.3"
use_gpu = True

username = spark.sql("select current_user()").first()['current_user()']
tmp_dir = f"/tmp/{username}"
tmp_experiment_dir = f"/Users/{username}/SDV"

print(f"Using tmp_dir: {tmp_dir}")
print(f"Using tmp_experiment_dir: {tmp_experiment_dir}")
Using tmp_dir: /tmp/sean.owen@databricks.com Using tmp_experiment_dir: /Users/sean.owen@databricks.com/SDV

This example will turn again to our old friend, one of several NYC Taxi-related datasets like the one used in a Kaggle competition. This data is already available in /databricks-datasets in your workspace. It describes a huge number of taxi rides, including their pickup and drop-off time and place, distance, tolls, vendor, etc.

Imagine we wish to predict the tip that the rider will add after the trip, using this dataset. (For this reason, total_amount is redacted, as it would be a target leak.) Who knows? Maybe this could be used to intelligently suggest a tip amount.

It's a huge dataset, so only a small sample will be used. This is, generally, a straightforward regression problem.

# Stick to reliable data from the years 2009-2016
train_df, test_df = spark.read.format("delta").load("/databricks-datasets/nyctaxi/tables/nyctaxi_yellow").\
  filter("YEAR(pickup_datetime) >= 2009 AND YEAR(pickup_datetime) <= 2016").\
  drop("total_amount").\
  sample(0.0005, seed=42).\
  randomSplit([0.9, 0.1], seed=42)
  
table_nyctaxi = train_df.toPandas().sample(frac=1, random_state=42, ignore_index=True) # random shuffle for good measure
table_nyctaxi.head(5)

Building a Synthetic Dataset

Of course, it'd be easy to get started from here and build a regressor, with AutoML or any standard open-source library, if one had access to the data above.

Imagine that this dataset is sensitive or otherwise not shareable with data science practitioners, yet these practitioners need to produce a model that accurately predicts the tip. It's not as crazy as it sounds; even within an organization, it's possible that important data is tightly controlled and otherwise unavailable for data science experimentation, which is unfortunate for the experimenters.

A First Pass with SDV

This is where a library like SDV (Synthetic Data Vault) comes in. SDV is a toolkit for synthesizing data that looks like a given source. It can handle multi-table schemas with foreign keys and anonymization of PII, and it implements sophisticated synthesis techniques based on copulas, GANs, and VAEs.
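For instance, if a table contained direct identifiers, SDV's tabular models can replace them with fake values at fit time. A minimal sketch, assuming a hypothetical rider_name column (not part of the taxi data) and the anonymize_fields option of SDV 0.x, which maps a column to a Faker category:

import pandas as pd
from sdv.tabular import GaussianCopula

# Hypothetical toy table with a PII column; not part of the taxi dataset
pii_df = pd.DataFrame({
    "rider_name": ["Alice Smith", "Bob Jones", "Carol Diaz", "Dan Lee"],
    "tip_amount": [2.5, 0.0, 4.0, 1.0],
})

# anonymize_fields maps a column name to a Faker category; synthesized values are fakes
model_pii = GaussianCopula(anonymize_fields={"rider_name": "name"})
model_pii.fit(pii_df)
model_pii.sample(num_rows=4)  # rider_name contains generated names, never the originals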

However, it also provides a fast and easy-to-use "preset" for simple single-table setups (built on Gaussian copulas, for the interested). Give it a try:

from sdv.metadata.dataset import Metadata
from sdv.lite import TabularPreset

metadata = Metadata()
metadata.add_table(name="nyctaxi_yellow", data=table_nyctaxi)

model = TabularPreset(name='FAST_ML', metadata=metadata.get_table_meta("nyctaxi_yellow"))
model.fit(table_nyctaxi)

model.sample(num_rows=5, randomize_samples=False)
display(spark.createDataFrame(model.sample(num_rows=1000, randomize_samples=False)))
 
(Output: 1,000 synthesized rows, truncated display; 17 columns including vendor_id, pickup_datetime, dropoff_datetime, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code_id, ...)

At a glance, that looks like plausible real data. A deeper glance reveals some problems, though:

  • Some monetary amounts are negative, like the MTA tax or tip
  • Passenger count and distance are sometimes 0
  • Distance is occasionally impossibly short - less than the straight-line distance between pickup and dropoff
  • Longitude and latitude are sometimes nowhere near New York City (in some cases entirely invalid, like >90 degrees latitude)
  • Monetary amounts have more than two decimal places
  • Pickup time is occasionally after dropoff time (probably daylight saving time artifacts), or the trip lasts more than 12 hours
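To go beyond eyeballing, these issues can be counted directly on a fresh synthetic sample (a quick sketch; the checks simply mirror the bullet points above):

# Quantify the issues listed above in a new synthetic sample
synth_check = model.sample(num_rows=10000, randomize_samples=False)
print("negative tip_amount:     ", (synth_check["tip_amount"] < 0).sum())
print("zero passengers:         ", (synth_check["passenger_count"] <= 0).sum())
print("zero distance:           ", (synth_check["trip_distance"] <= 0).sum())
print("dropoff before pickup:   ", (synth_check["dropoff_datetime"] < synth_check["pickup_datetime"]).sum())
print("latitude out of [-90,90]:", (synth_check["pickup_latitude"].abs() > 90).sum())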

Many of these are actually problems found in the source data too! We can nevertheless proceed to get a report on the data quality from SDV:

from sdmetrics.reports.single_table import QualityReport

report = QualityReport()
report.generate(table_nyctaxi, 
                model.sample(num_rows=10000, randomize_samples=False), 
                metadata.get_table_meta("nyctaxi_yellow"))

report.get_visualization("Column Pair Trends")
Creating report: 100%|██████████| 4/4 [00:12<00:00, 3.11s/it] Overall Quality Score: 75.18% Properties: Column Shapes: 66.88% Column Pair Trends: 83.47%

The quality is "OK". Pairwise correlations look similar between real and synthetic data; there is some issue with the synthesized store_and_fwd_flag column. Clearly, the synthetic data issues need fixing.

"Quality" is about 75%. This metric is just the average of two quality scores, based on individual column shapes, and column pairs - more or less how much columns, and pairs of columns, in the synthetic data look like real data.

Really, data engineers should fix these data problems before data scientists seriously consider synthesizing data from the source. Here, the bad rows can simply be filtered from the input. Later, these observations will also translate into Constraints that SDV can apply to its output, such as enforcing that a number is positive or within a range.

Start over and fix the input data by filtering:

from pyspark.sql.functions import col, pandas_udf
import numpy as np

df = spark.read.format("delta").load("/databricks-datasets/nyctaxi/tables/nyctaxi_yellow").\
  filter("YEAR(pickup_datetime) >= 2009 AND YEAR(pickup_datetime) <= 2016").\
  drop("total_amount")

for c in ["fare_amount", "extra", "mta_tax", "tip_amount", "tolls_amount"]:
  df = df.filter(f"{c} >= 0")
df = df.filter("passenger_count > 0")
df = df.filter("trip_distance > 0 AND trip_distance < 100")
df = df.filter("dropoff_datetime > pickup_datetime")
df = df.filter("CAST(dropoff_datetime AS long) < CAST(pickup_datetime AS long) + 12 * 60 * 60")
for c in ["pickup_longitude", "dropoff_longitude"]:
  df = df.filter(f"{c} > -76 AND {c} < -72")
for c in ["pickup_latitude", "dropoff_latitude"]:
  df = df.filter(f"{c} > 39 AND {c} < 43")

# Define this as a standalone function for reuse later
def haversine_dist_miles(from_lat_deg, from_lon_deg, to_lat_deg, to_lon_deg):
  to_lat = np.deg2rad(to_lat_deg)
  to_lon = np.deg2rad(to_lon_deg)
  from_lat = np.deg2rad(from_lat_deg)
  from_lon = np.deg2rad(from_lon_deg)
  # 3958.8 is avg earth radius in miles
  return 3958.8 * 2 * np.arcsin(np.sqrt(
    np.square(np.sin((to_lat - from_lat) / 2)) + np.cos(to_lat) * np.cos(from_lat) * np.square(np.sin((to_lon - from_lon) / 2)))) 
  
@pandas_udf('double')
def haversine_dist_miles_udf(from_lat_deg, from_lon_deg, to_lat_deg, to_lon_deg):
  return haversine_dist_miles(from_lat_deg, from_lon_deg, to_lat_deg, to_lon_deg)

# Allow 90% of min theoretical distance to account for rounding, inaccuracy
df = df.filter(col("trip_distance") >= 
               0.9 * haversine_dist_miles_udf("pickup_latitude", "pickup_longitude", "dropoff_latitude", "dropoff_longitude"))
  
train_df, test_df = df.sample(0.0005, seed=42).randomSplit([0.9, 0.1], seed=42)
train_df.cache()

table_nyctaxi = train_df.toPandas().sample(frac=1, random_state=42, ignore_index=True)
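As a quick sanity check of the Haversine helper: the straight-line distance from JFK Airport to Times Square should come out to roughly 13-14 miles (coordinates are approximate):

# Approximate JFK Airport -> Times Square; expect roughly 13-14 straight-line miles
print(haversine_dist_miles(40.6413, -73.7781, 40.7580, -73.9855))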

Try again with TabularPreset on the fixed-up data:

metadata = Metadata()
metadata.add_table(name="nyctaxi_yellow", data=table_nyctaxi)

model = TabularPreset(name='FAST_ML', metadata=metadata.get_table_meta("nyctaxi_yellow"))
model.fit(table_nyctaxi)

report = QualityReport()
report.generate(table_nyctaxi, 
                model.sample(num_rows=10000, randomize_samples=False), 
                metadata.get_table_meta("nyctaxi_yellow"))

report.get_visualization("Column Pair Trends")
Creating report: 100%|██████████| 4/4 [00:09<00:00, 2.48s/it] Overall Quality Score: 82.11% Properties: Column Shapes: 78.12% Column Pair Trends: 86.1%

Quality went from 75% to 82%, as it's easier to imitate data that doesn't have odd outliers. Can we do better with some advanced techniques?

Adding Constraints, Variational Autoencoders, and MLflow

Let's jump ahead and add several improvements. Below, SDV's TVAE module (a variational autoencoder adapted to tabular data) is used for more sophisticated (and compute-intensive) synthesis of data. This part can and should be accelerated with a GPU. SDV also has CTGAN and CopulaGAN, though these turn out to be less effective on this data. Deep learning, GPUs - this is real-deal AI!

Constraints are also added, per above, to improve the realism of the output. This includes a custom Constraint that checks trip_distance against the straight-line (Haversine) distance between the pickup and dropoff lat/lon, and a custom Constraint limiting the duration of the trip.

from sdv.constraints import FixedIncrements, Inequality, Positive, ScalarInequality, ScalarRange, create_custom_constraint
import numpy as np

# Add constraints mirroring those above

constraints = []

# Distance shouldn't be (too) much less than straight line distance
def is_trip_distance_valid(column_names, data):
  dist_col, from_lat, from_lon, to_lat, to_lon = column_names
  return data[dist_col] >= 0.9 * haversine_dist_miles(data[from_lat], data[from_lon], data[to_lat], data[to_lon])

TripDistanceValid = create_custom_constraint(is_valid_fn=is_trip_distance_valid)
constraints += [TripDistanceValid(column_names=["trip_distance", "pickup_latitude", "pickup_longitude", "dropoff_latitude", "dropoff_longitude"])]

# Dropoff shouldn't be more than 12 hours after pickup, or before pickup
def is_duration_valid(column_names, data):
  pickup_col, dropoff_col = column_names
  return (data[dropoff_col] - data[pickup_col]) < np.timedelta64(12, 'h')

DurationValid = create_custom_constraint(is_valid_fn=is_duration_valid)
constraints += [DurationValid(column_names=["pickup_datetime", "dropoff_datetime"])]
constraints += [Inequality(low_column_name="pickup_datetime", high_column_name="dropoff_datetime")]

# Monetary amounts should be positive
constraints += [ScalarInequality(column_name=c, relation=">=", value=0) for c in 
                 ["fare_amount", "extra", "mta_tax", "tip_amount", "tolls_amount"]]
# Passengers should be a positive integer
constraints += [FixedIncrements(column_name="passenger_count", increment_value=1)]
constraints += [Positive(column_name="passenger_count")]
# Distance should be positive and not (say) more than 100 miles
constraints += [ScalarRange(column_name="trip_distance", low_value=0, high_value=100)]
# Lat/lon should be in some credible range around New York City
constraints += [ScalarRange(column_name=c, low_value=-76, high_value=-72) for c in ["pickup_longitude", "dropoff_longitude"]]
constraints += [ScalarRange(column_name=c, low_value=39, high_value=43) for c in ["pickup_latitude", "dropoff_latitude"]]
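Before fitting, it's worth confirming that the filtered real data itself satisfies every constraint, since rows that violate constraints can cause problems during fitting. A small check, assuming each constraint object exposes an is_valid method as in SDV 0.x:

# Each constraint should report (nearly) 100% of real rows as valid
for c in constraints:
    valid = c.is_valid(table_nyctaxi)
    print(f"{type(c).__name__}: {valid.mean():.1%} of rows valid")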

Finally, the whole modeling process is also tracked via MLflow - data quality metrics, plots, and even a "model" based on SDV that can be loaded or deployed to generate synthetic data as its output.

import mlflow
from mlflow.models import infer_signature
from sdv.metadata.dataset import Metadata
from sdv.tabular import TVAE
from sdmetrics.reports.single_table import QualityReport
import pandas as pd

# Won't log 'internal' sklearn models fit by SDV
mlflow.autolog(disable=True) 

# Wrapper convenience model that lets the SDV model "predict" new synthetic data
class SynthesizeModel(mlflow.pyfunc.PythonModel):
  def __init__(self, model):
    self.model = model

  def predict(self, context, model_input):
    return self.model.sample(num_rows=len(model_input))

with mlflow.start_run():
  metadata = Metadata()
  metadata.add_table(name="nyctaxi_yellow", data=table_nyctaxi)

  model = TVAE(constraints=constraints, batch_size=1000, epochs=500, cuda=use_gpu)
  model.fit(table_nyctaxi)
  
  sample = model.sample(num_rows=10000, randomize_samples=False)
  report = QualityReport()
  report.generate(table_nyctaxi, sample, metadata.get_table_meta("nyctaxi_yellow"))
  
  # Log metrics and plots with MLflow
  mlflow.log_metric("Quality Score", report.get_score())
  for (prop, score) in report.get_properties().to_numpy().tolist():
    mlflow.log_metric(prop, score)
    mlflow.log_dict(report.get_details(prop).to_dict(orient='records'), f"{prop}.json")
    prop_viz = report.get_visualization(prop)
    display(prop_viz)
    mlflow.log_figure(prop_viz, f"{prop}.png")
  
  # Log wrapper model for synthesis of data, if desired
  # Not strictly necessary; this model's .pkl serialization could have been logged as an artifact,
  # or not at all
  if use_gpu:
    # Assign model to CPU for later inference; GPU not really useful
    model._model.set_device('cpu')
  synthesize_model = SynthesizeModel(model)
  dummy_input = pd.DataFrame([True], columns=["dummy"]) # dummy value
  signature = infer_signature(dummy_input, synthesize_model.predict(None, dummy_input))
  mlflow.pyfunc.log_model("model", python_model=synthesize_model, 
                          registered_model_name="sdv_synth_model",
                          input_example=dummy_input, signature=signature)
Sampling rows: 100%|██████████| 10000/10000 [00:02<00:00, 4218.93it/s] Creating report: 100%|██████████| 4/4 [00:07<00:00, 1.81s/it] Overall Quality Score: 83.25% Properties: Column Shapes: 78.51% Column Pair Trends: 88.0%
Sampling rows: 100%|██████████| 1/1 [00:00<00:00, 1.67it/s] /databricks/python/lib/python3.9/site-packages/mlflow/models/signature.py:131: UserWarning: Hint: Inferred schema contains integer column(s). Integer columns in Python cannot represent missing values. If your input data contains missing values at inference time, it will be encoded as floats and will cause a schema enforcement error. The best way to avoid this problem is to infer the model schema based on a realistic data sample (training dataset) that includes missing values. Alternatively, you can declare integer columns as doubles (float64) whenever these columns may have missing values. See `Handling Integers With Missing Values <https://www.mlflow.org/docs/latest/models.html#handling-integers-with-missing-values>`_ for more details. /databricks/python/lib/python3.9/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils. Successfully registered model 'sdv_synth_model'. 2023/01/25 01:47:12 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation. Model name: sdv_synth_model, version 1 Created version '1' of model 'sdv_synth_model'.

Quality is up slightly, to 83.2%.

Check out the MLflow run linked in the cell output above to see what MLflow captured, and even registered as model sdv_synth_model. MLflow has all of the plots and metrics, and even the data generator as a 'model' for later use.

Also, before the next step, move this latest version of sdv_synth_model into the Production stage, in the Model Registry UI.
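If you prefer to script it rather than click through the UI, the stage transition can also be done with the MLflow client (a sketch; it promotes whatever version was just registered):

from mlflow.tracking import MlflowClient

client = MlflowClient()
# Promote the most recently registered version (currently in stage "None") to Production
latest = client.get_latest_versions("sdv_synth_model", stages=["None"])[0]
client.transition_model_version_stage(name="sdv_synth_model", version=latest.version, stage="Production")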

Generating the Data

Now, to generate some synthetic data! Above, fitting the generative model took a while and does not parallelize across a cluster, but applying the model to create data can be neatly parallelized with Spark. It's simplest here to load the original SDV model from MLflow, write a small function that makes data, and then apply it in parallel with Spark:

import mlflow

# Pick out the raw SDV model inside the wrapper
sdv_model = mlflow.pyfunc.load_model("models:/sdv_synth_model/Production")._model_impl.python_model.model

# Simple function to generate data from the model. The input could really be anything; here the input
# is assumed to be the number of rows to generate.
def synthesize_data(how_many_dfs):
  for how_many_df in how_many_dfs:
    # This will generate different data every run, note; can't be seeded, except to make it return
    # the same data every single call!
    # output_file_path='disable' is a workaround (?) for temp file errors
    yield sdv_model.sample(num_rows=how_many_df.sum().item(), output_file_path='disable')

# Generate, for example, the same number of rows as in the input
how_many = len(table_nyctaxi)
partitions = 256
synth_df = spark.createDataFrame([(how_many // partitions,)] * partitions).\
  repartition(partitions).\
  mapInPandas(synthesize_data, schema=df.schema)

display(synth_df)
2023/01/25 02:12:24 WARNING mlflow.pyfunc: Detected one or more mismatches between the model's dependencies and the current Python environment: - dill (current: 0.3.6, required: dill==0.3.4) To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.
 
(Output: 1,000 synthesized rows, truncated display; 17 columns including vendor_id, pickup_datetime, dropoff_datetime, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code_id, ...)
synth_data_path = f"{tmp_dir}/synth/nyctaxi_synth"
synth_df.write.format("delta").save(synth_data_path)

Much better! Whole numbers of passengers, money that looks like dollars and cents, reasonable-looking locations.

Comparing Synthetic Data

In Databricks, you can generate a profile of any dataset. Here we want to compare the original and synthetic data sets, to get a sense of how much they match. Use pandas-profiling:

from pandas_profiling import ProfileReport

synth_data_df = spark.read.format("delta").load(synth_data_path).toPandas()

original_report = ProfileReport(table_nyctaxi, title='Original Data', minimal=True)
synth_report = ProfileReport(synth_data_df, title='Synthetic Data', minimal=True)
compare_report = original_report.compare(synth_report)
compare_report.config.html.navbar_show = False
compare_report.config.html.full_width = True
displayHTML(compare_report.to_html())
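The comparison report can also be written out as a standalone HTML file for sharing outside the notebook (a sketch; the path is arbitrary):

# Save the comparison report as HTML alongside the synthetic data on DBFS
compare_report.to_file(f"/dbfs{tmp_dir}/synth/profile_compare.html")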

There is definitely broad similarity in distributions of individual features. There are also some odd differences, like the non-uniformity of pickup/dropoff time in the synthetic data. Addressing this may be a matter of further tuning the synthetic data process, and is out of scope here.

It's worth saying that there are limits to any data synthesis process. For example, here the model can generate pickup/dropoff points that resemble the input's points, in their range and distribution and even in their relation to each other considering the distance. But it has no direct way of knowing whether the points make sense as places on a street to pick up a person.

There isn't a free lunch here, and the generated data at best roughly resembles real data. That's part of the point, of course: it shouldn't be so realistic that real data 'leaks' into the output. How good is it? Let's try building a model with the synthetic data set saved above.

Modeling on Synthetic Data

This is the point where the data scientist reenters the picture. We're in a development environment. They have access to this synthetic data set now, and the task is to build a model that predicts real tips reasonably well. A good start would be to simply use AutoML to fit a reasonable model to the synthetic data:

synth_data_path = f"{tmp_dir}/synth/nyctaxi_synth"
import databricks.automl
databricks.automl.regress(
  spark.read.format("delta").load(synth_data_path),
  target_col="tip_amount",
  primary_metric="rmse",
  experiment_dir=tmp_experiment_dir,
  experiment_name="Synth models",
  timeout_minutes=120)
2023/01/26 14:29:57 INFO databricks.automl.client.manager: AutoML will optimize for root mean squared error metric, which is tracked as val_root_mean_squared_error in the MLflow experiment. 2023/01/26 14:29:58 INFO databricks.automl.client.manager: MLflow Experiment ID: 2408902617877312 2023/01/26 14:29:58 INFO databricks.automl.client.manager: MLflow Experiment: https://e2-demo-field-eng.cloud.databricks.com/?o=1444828305810485#mlflow/experiments/2408902617877312 2023/01/26 14:31:44 INFO databricks.automl.client.manager: Data exploration notebook: https://e2-demo-field-eng.cloud.databricks.com/?o=1444828305810485#notebook/2408902617877330 2023/01/26 15:23:42 INFO databricks.automl.client.manager: AutoML experiment completed successfully.
Out[5]: <databricks.automl.shared.result.AutoMLSummary at 0x7f7bbf3b3a00>

The best model (after a couple of hours, at least) had a test RMSE of about 1.1 and R-squared of 0.63. Not awful. The real question is: how well does this model, trained on synthetic data, hold up when applied to real data?

This is the kind of next step that might happen in a staging or testing environment, where the model produced by the data science team is tested before deployment. That environment should have some real data to test on, before the model faces, well, real data!

A simple snippet of what might transpire is below. Here, real test data has already been loaded, so evaluate metrics on that:

import mlflow
from sklearn.metrics import mean_squared_error, r2_score
from math import sqrt

test_pd = test_df.toPandas()

def print_metrics(exp_name):
  best_runs = mlflow.search_runs(
    experiment_names=[f"{tmp_experiment_dir}/{exp_name}"], 
    order_by=["metrics.val_root_mean_squared_error"],
    max_results=1)
  run_id = best_runs['run_id'].item()
  model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
  y_pred = model.predict(test_pd.drop("tip_amount", axis=1))
  y_true = test_pd["tip_amount"]
  print(f"RMSE: {sqrt(mean_squared_error(y_true, y_pred))}")
  print(f"R^2:  {r2_score(y_true, y_pred)}")
  
print_metrics("Synth models")
/databricks/python/lib/python3.9/site-packages/sklearn/preprocessing/_encoders.py:170: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros warnings.warn( RMSE: 1.525253299446487 R^2: 0.48683964004022906

Comparing to Modeling on Original Data

Pretty comparable metrics, actually, which is good news. RMSE and R^2 are 1.52 and 0.49, versus the 1.4 and 0.49 estimated from the held-out (synthetic) test set in AutoML. The model's performance in this case held up reasonably well on real data.

But, wait, how would we have done if we'd fit a model on a roughly equal amount of real data?

import databricks.automl

databricks.automl.regress(
  train_df,
  target_col="tip_amount",
  primary_metric="rmse",
  experiment_dir=tmp_experiment_dir,
  experiment_name="Actual data models",
  timeout_minutes=120)
2023/01/26 03:05:24 INFO databricks.automl.client.manager: AutoML will optimize for root mean squared error metric, which is tracked as val_root_mean_squared_error in the MLflow experiment. 2023/01/26 03:05:25 INFO databricks.automl.client.manager: MLflow Experiment ID: 2408902617760708 2023/01/26 03:05:25 INFO databricks.automl.client.manager: MLflow Experiment: https://e2-demo-field-eng.cloud.databricks.com/?o=1444828305810485#mlflow/experiments/2408902617760708 2023/01/26 03:10:28 INFO databricks.automl.client.manager: Data exploration notebook: https://e2-demo-field-eng.cloud.databricks.com/?o=1444828305810485#notebook/2408902617765065 2023/01/26 03:34:07 INFO databricks.automl.client.manager: AutoML experiment completed successfully.
Out[22]: <databricks.automl.shared.result.AutoMLSummary at 0x7f67e7fd0c70>
print_metrics("Actual data models")
RMSE: 0.9387479716191751 R^2: 0.7784209844478875

Modeling on real data would have done better here, for sure. RMSE is about 0.94 and R^2 is 0.78, vs 1.5 and 0.49. That's a significant difference. In other cases, the model on synthetic data might be just about as good, but not quite here.

However, the synthetic data modeling process did provide something: data scientists verified the viability of building a decent model, and of the model-building approach, without using real data.

Synthetic Data for Testing

One might also think of synthetic data as a tool for testing a model's behavior across a wider range of inputs. Does it fail on some input, or give an outlandish answer? The synthetic data, by nature, won't look extreme compared to real data, but might exercise unusual but realistic inputs not found in training or test data.

For example, synthetic data might be used in some kind of integration test like below, where one looks for predictions that seem simply out of normal ranges.

import mlflow

best_runs = mlflow.search_runs(
  experiment_names=[f"{tmp_experiment_dir}/Synth models"], 
  order_by=["metrics.val_root_mean_squared_error"],
  max_results=1)
run_id = best_runs['run_id'].item()
model_udf = mlflow.pyfunc.spark_udf(spark, f"runs:/{run_id}/model")

synth_df = spark.read.format("delta").load(synth_data_path).drop("tip_amount")
display(synth_df.withColumn("prediction", model_udf(*synth_df.drop("tip_amount").columns)).filter("prediction < 0 OR prediction > 100"))
2023/01/26 01:55:03 WARNING mlflow.pyfunc: Detected one or more mismatches between the model's dependencies and the current Python environment: - cloudpickle (current: 2.2.1, required: cloudpickle==2.0.0) To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file. 2023/01/26 01:55:03 WARNING mlflow.pyfunc: Calling `spark_udf()` with `env_manager="local"` does not recreate the same environment that was used during training, which may lead to errors or inaccurate predictions. We recommend specifying `env_manager="conda"`, which automatically recreates the environment that was used to train the model and performs inference in the recreated environment. 2023/01/26 01:55:03 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'
 
(Output: synthesized rows whose predicted tip falls outside the 0-100 range; truncated display showing 1,000 rows, 17 columns including vendor_id, pickup_datetime, dropoff_datetime, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code_id, ...)

Oops, this reveals a problem with the model, though probably one that would have been visible testing on any data: predicted tips are sometimes less than 0. Really, this regression should have been formulated as something like a log-linear model, given the distribution of tips (non-negative and heavily right-skewed), to avoid this; this example will forgo pursuing that.
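For illustration only, here is a minimal sketch of that idea outside of AutoML: fit on a log-transformed target with scikit-learn and invert the transform at prediction time, using just a few numeric features as a stand-in. (The feature list is arbitrary; expm1 of a negative prediction can still be slightly below zero, so a final clip is cheap insurance.)

import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.ensemble import HistGradientBoostingRegressor

# Fit on log1p(tip_amount) and invert with expm1, so large tips don't dominate
# and predictions are (nearly) always non-negative
numeric_cols = ["trip_distance", "fare_amount", "tolls_amount", "passenger_count"]
log_model = TransformedTargetRegressor(
    regressor=HistGradientBoostingRegressor(),
    func=np.log1p,
    inverse_func=np.expm1,
)
log_model.fit(synth_data_df[numeric_cols], synth_data_df["tip_amount"])

preds = np.clip(log_model.predict(test_pd[numeric_cols]), 0, None)  # clip tiny negatives to 0
print("min predicted tip:", preds.min())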

Real and Synthetic Data: Why Don't We Have Both?

In many cases, data scientists do have access to real production data to train models. What use is synthetic data? It can be viewed as a form of data augmentation, generating more real-ish data to fit on. More data usually leads to better models. In theory, nothing really new has entered the picture here. There is no additional real data. Nevertheless, sometimes additional synthetic data does improve the result of a modeling process.

In particular, synthesis is useful when the data set is imbalanced in some way - some types of inputs are rare or missing from the input - and synthesizing data to fill that gap can be especially valuable. By definition, though, it's hardest to make realistic data for exactly those subsets that are rare!
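SDV can help here with conditional sampling: instead of sampling freely, ask the fitted model for rows that match an under-represented condition. A sketch, assuming the TVAE model fit above and the sample_conditions API of SDV 0.x (generation can be slow for rare conditions, since it relies on rejection sampling):

from sdv.sampling import Condition

# Ask the fitted model specifically for a rarer kind of row, e.g. 6-passenger trips
rare = Condition(column_values={"passenger_count": 6}, num_rows=500)
rare_synth = model.sample_conditions(conditions=[rare])
print(rare_synth["passenger_count"].value_counts())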

As a final experiment, try modeling on a mix of real and synthetic data:

import databricks.automl

databricks.automl.regress(
  train_df.union(spark.read.format("delta").load(synth_data_path)),
  target_col="tip_amount",
  primary_metric="rmse",
  experiment_dir=tmp_experiment_dir,
  experiment_name="Hybrid data models",
  timeout_minutes=120)
2023/01/26 03:35:19 INFO databricks.automl.client.manager: AutoML will optimize for root mean squared error metric, which is tracked as val_root_mean_squared_error in the MLflow experiment. 2023/01/26 03:35:20 INFO databricks.automl.client.manager: MLflow Experiment ID: 2408902617780328 2023/01/26 03:35:20 INFO databricks.automl.client.manager: MLflow Experiment: https://e2-demo-field-eng.cloud.databricks.com/?o=1444828305810485#mlflow/experiments/2408902617780328 2023/01/26 03:41:34 INFO databricks.automl.client.manager: Data exploration notebook: https://e2-demo-field-eng.cloud.databricks.com/?o=1444828305810485#notebook/2408902617780345 2023/01/26 05:37:58 INFO databricks.automl.client.manager: AutoML experiment completed successfully.
Out[25]: <databricks.automl.shared.result.AutoMLSummary at 0x7f67e6936a60>
print_metrics("Hybrid data models")
RMSE: 0.9509656549459966 R^2: 0.7726158074685747

Very nearly identical performance, in this case. The synthetic data didn't help or hurt. It's possible that more accurate synthetic data would yield a better result in another case.