祝海林
Jun 16, 2021 · 5 min read


Introduction to MLSQL deep learning [2] -Distributed model training

All code examples in this article are based on the latest version of MLSQL Engine, 2.1.0-SNAPSHOT.

This article uses the MLSQL Console notebook to demonstrate the "Hello world" of deep learning: training a model on the MNIST dataset.

List of series articles:

  1. MLSQL Machine Learning Minimalist Tutorial (No Python Required!)
  2. MLSQL deep learning introduction [1]
  3. Introduction to MLSQL deep learning [2] -Distributed model training
  4. Introduction to MLSQL deep learning [3] -Feature engineering
  5. Introduction to MLSQL deep learning [4] -Serving

For environment requirements and data preparation, please refer to the previous article: Introduction to MLSQL deep learning [1].

We will continue to use the dataset from the previous article.

Note that a Ray environment is required for this article.

Load data

We can load the data in the notebook of MLSQL Console as follows:

load parquet.`/tmp/mnist` as mnist;

Python environment configuration

Next, you need to specify that the Python client runs on the driver side and select the corresponding Python environment. Because the output we get from Python is a model directory, we set the schema to file. For non-ETL tasks like this one, dataMode must be model.

!python env "PYTHON_ENV=source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.3.0";
!python conf "runIn=driver";
!python conf "schema=file";
!python conf "dataMode=model";

Basic concept

Now, we can start writing python code, which looks like this:

As in the previous article, the first step is to declare, in the notebook cell, the language of the cell, the input table name, the output table name, and whether caching is needed. This is done with the following annotations:

--%python
--%input=mnist
--%cache=true
--%output=mnist_model

Next, we need to get the ray_context session object in the following way:

ray_context = RayContext.connect(globals(),"127.0.0.1:10001")

The ray_context object lets us read input data and write output data from Python. Here, the input is the mnist table that we loaded earlier with the load statement. We can get a reference to each data partition with the following code:

data_servers = ray_context.data_servers()

data_servers is an array. The length of the array is the number of partitions.

replica_num = len(data_servers)
print(f"total workers {replica_num}")

In the code that follows, we will launch replica_num TensorFlow workers to train our model in a distributed way.

Model and Data

We will use a fully connected network, which looks like this:

def create_tf_model():
    network = models.Sequential()
    network.add(layers.Dense(512, activation="relu", input_shape=(28*28,)))
    network.add(layers.Dense(10, activation="softmax"))
    network.compile(optimizer="sgd",
                    loss="categorical_crossentropy",
                    metrics=["accuracy"])
    return network
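If you want to sanity-check the architecture locally before wiring it into Ray, you can instantiate it and print a summary. This is only a sketch and assumes TensorFlow 2.x with the imports shown in the full listing at the end of the article:

model = create_tf_model()
# A 784 -> 512 -> 10 dense network; summary() should report roughly 407,050
# trainable parameters (784*512 + 512 + 512*10 + 10).
model.summary()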

Next, we can use RayContext.collect_from to turn a data-partition reference into a real data generator, as in the following code (in this example, we simply collect everything into memory for convenience):

def data_partition_creater(data_server):
    temp_data = [item for item in RayContext.collect_from([data_server])]
    train_images = np.array([np.array(item["image"]) for item in temp_data])
    train_labels = np_utils.to_categorical(
        np.array([item["label"] for item in temp_data])
    )
    train_images = train_images.reshape((len(temp_data), 28*28))
    return train_images, train_labels
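As an illustration only (in the training flow below, this function is called inside each worker, not on the driver), invoking it against a single partition would give arrays with the following shapes, assuming the MNIST rows carry a 784-pixel "image" and an integer "label":

images, labels = data_partition_creater(data_servers[0])
print(images.shape)  # (N, 784), where N is the number of rows in this partition
# (N, 10) one-hot labels, assuming all ten digit classes appear in the partition
print(labels.shape)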

In this article, we use Ray Actors to create the TensorFlow workers that perform the distributed training. The number of Actors equals the number of data partitions.

Create TF Worker

We will use Ray's Actor to build the workers, and then train in a parameter-server style: after each epoch, the workers' weights are averaged and pushed back to every worker.

The Worker class is defined as follows:

@ray.remote
class Network(object):
    def __init__(self, data_server):
        self.model = create_tf_model()
        # you can also save the data to local disk if the data does
        # not fit in memory
        self.train_images, self.train_labels = data_partition_creater(data_server)

    def train(self):
        history = self.model.fit(self.train_images, self.train_labels, batch_size=128)
        return history.history

    def get_weights(self):
        return self.model.get_weights()

    def set_weights(self, weights):
        # Note that for simplicity this does not handle the optimizer state.
        self.model.set_weights(weights)

    def get_final_model(self):
        model_path = os.path.join("/", "tmp", "minist_model")
        self.model.save(model_path)
        model_binary = [item for item in streaming_tar.build_rows_from_file(model_path)]
        return model_binary

    def shutdown(self):
        ray.actor.exit_actor()

Start TF Worker

Now, start one worker per data shard (note that these workers are all independent processes distributed across the Ray cluster):

workers = [Network.remote(data_server) for data_server in data_servers]

Start training

Run the first training epoch, then collect the model weights from each worker:

ray.get([worker.train.remote() for worker in workers])
_weights = ray.get([worker.get_weights.remote() for worker in workers])

Define a function that averages the weights, pushes them back to the workers, and runs another epoch:

def epoch_train(weights):
    sum_weights = reduce(lambda a, b: [(a1 + b1) for a1, b1 in zip(a, b)], weights)
    averaged_weights = [layer / replica_num for layer in sum_weights]
    ray.get([worker.set_weights.remote(averaged_weights) for worker in workers])
    ray.get([worker.train.remote() for worker in workers])
    return ray.get([worker.get_weights.remote() for worker in workers])
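To make the averaging step concrete, here is a toy example with two fake weight lists standing in for two workers (illustrative only; real Keras weights are lists of much larger arrays, one entry per layer variable):

from functools import reduce
import numpy as np

w_a = [np.array([1.0, 3.0]), np.array([2.0])]   # "weights" from worker A
w_b = [np.array([3.0, 5.0]), np.array([4.0])]   # "weights" from worker B

# Element-wise sum per layer, then divide by the number of workers.
summed = reduce(lambda a, b: [(a1 + b1) for a1, b1 in zip(a, b)], [w_a, w_b])
averaged = [layer / 2 for layer in summed]
print(averaged)  # [array([2., 4.]), array([3.])]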

Now you can train:

for epoch in range(6):
    _weights = epoch_train(_weights)

You will see the following log:

Return model

Finally, we pick one of the workers (here the first) and have it save the model:

model_binary = ray.get(workers[0].get_final_model.remote())

Shut down all workers:

[worker.shutdown.remote() for worker in workers]

Return the model to the system:

ray_context.build_result(model_binary)

Save the model to Data Lake

save overwrite mnist_model as delta.`ai_model.mnist_model`;
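To verify that the model was persisted, you can load it back from the Delta table in a later cell. This is a minimal sketch that reuses the same load syntax shown earlier; loaded_mnist_model is just an illustrative table name:

load delta.`ai_model.mnist_model` as loaded_mnist_model;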

The output is as follows:

The complete Notebook is as follows

Full Python Code

--%python
--%input=mnist
--%output=mnist_model
--%cache=true

from functools import reduce
import os

import ray
import numpy as np
from tensorflow.keras import models, layers
from tensorflow.keras import utils as np_utils
from pyjava.api.mlsql import RayContext
from pyjava.storage import streaming_tar

ray_context = RayContext.connect(globals(), "127.0.0.1:10001")

data_servers = ray_context.data_servers()
replica_num = len(data_servers)
print(f"total workers {replica_num}")


def data_partition_creater(data_server):
    temp_data = [item for item in RayContext.collect_from([data_server])]
    train_images = np.array([np.array(item["image"]) for item in temp_data])
    train_labels = np_utils.to_categorical(
        np.array([item["label"] for item in temp_data])
    )
    train_images = train_images.reshape((len(temp_data), 28*28))
    return train_images, train_labels


def create_tf_model():
    network = models.Sequential()
    network.add(layers.Dense(512, activation="relu", input_shape=(28*28,)))
    network.add(layers.Dense(10, activation="softmax"))
    network.compile(optimizer="sgd",
                    loss="categorical_crossentropy",
                    metrics=["accuracy"])
    return network


@ray.remote
class Network(object):
    def __init__(self, data_server):
        self.model = create_tf_model()
        # you can also save the data to local disk if the data does
        # not fit in memory
        self.train_images, self.train_labels = data_partition_creater(data_server)

    def train(self):
        history = self.model.fit(self.train_images, self.train_labels, batch_size=128)
        return history.history

    def get_weights(self):
        return self.model.get_weights()

    def set_weights(self, weights):
        # Note that for simplicity this does not handle the optimizer state.
        self.model.set_weights(weights)

    def get_final_model(self):
        model_path = os.path.join("/", "tmp", "minist_model")
        self.model.save(model_path)
        model_binary = [item for item in streaming_tar.build_rows_from_file(model_path)]
        return model_binary

    def shutdown(self):
        ray.actor.exit_actor()


workers = [Network.remote(data_server) for data_server in data_servers]

ray.get([worker.train.remote() for worker in workers])
_weights = ray.get([worker.get_weights.remote() for worker in workers])


def epoch_train(weights):
    sum_weights = reduce(lambda a, b: [(a1 + b1) for a1, b1 in zip(a, b)], weights)
    averaged_weights = [layer / replica_num for layer in sum_weights]
    ray.get([worker.set_weights.remote(averaged_weights) for worker in workers])
    ray.get([worker.train.remote() for worker in workers])
    return ray.get([worker.get_weights.remote() for worker in workers])


for epoch in range(6):
    _weights = epoch_train(_weights)

model_binary = ray.get(workers[0].get_final_model.remote())

[worker.shutdown.remote() for worker in workers]

ray_context.build_result(model_binary)
