Demonstrating Keras on an OCR problem using the MNIST data


import numpy as np
import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
from joblib import load, dump


# Fetch the four MNIST arrays through the `mnist` package.
# This downloads 52.4 MB of data as of November, 2019.
train_images = mnist.train_images()
train_labels = mnist.train_labels()
test_images = mnist.test_images()
test_labels = mnist.test_labels()


# Dump the downloaded arrays so they can be re-loaded with joblib later.
# Opening a file for writing does not create missing directories, so make
# sure the target folder exists first instead of failing with
# FileNotFoundError on a fresh checkout.
import os

os.makedirs('files_1', exist_ok=True)
dump(train_images, 'files_1/train_images.joblib')
dump(train_labels, 'files_1/train_labels.joblib')
dump(test_images, 'files_1/test_images.joblib')
dump(test_labels, 'files_1/test_labels.joblib')


# Preprocess both splits the same way: divide by 255 and shift by -0.5
# (so 0..255 inputs end up in -0.5..0.5), then flatten every image into
# a row of 784 values.
train_images = ((train_images / 255) - 0.5).reshape((-1, 784))
test_images = ((test_images / 255) - 0.5).reshape((-1, 784))

# Assemble the network layer by layer: two 64-unit ReLU layers followed by
# a 10-unit softmax output, fed with 784-element input vectors.
model = Sequential()
model.add(Dense(64, activation='relu', input_shape=(784,)))
model.add(Dense(64, activation='relu'))
model.add(Dense(10, activation='softmax'))

# Configure training: Adam optimizer, categorical cross-entropy loss,
# and accuracy as the reported metric.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Convert the label arrays with to_categorical once, up front.
train_targets = to_categorical(train_labels)
test_targets = to_categorical(test_labels)

# Fit on the training split for 5 epochs with batches of 32.
model.fit(train_images, train_targets, epochs=5, batch_size=32)

# Score the trained model on the test split.
model.evaluate(test_images, test_targets)


# Persist the learned weights to disk.
model.save_weights('files_1/model.h5')


# Round-trip the saved weights file through MongoDB using GridFS.

from pymongo import MongoClient
import gridfs

client = MongoClient('mongodb://vx5dpfvpz-08:27017/')
# NOTE(review): 'local' is normally MongoDB's per-instance system database;
# confirm it is the intended home for these files.
db = client['local']
fs = gridfs.GridFS(db)

# Upload the weights file. The context manager closes the local file handle
# as soon as GridFS has consumed it, so the handle is not leaked.
with open('files_1/model.h5', 'rb') as weights_file:
    fileID = fs.put(weights_file, filename="model.h5")

# Fetch the stored file back by the id returned from put().
out = fs.get(fileID)

# Print the length of every stored file named "model.h5".
for grid_out in fs.find({"filename": "model.h5"}, no_cursor_timeout=True):
    print(grid_out.length)

weights_temp = out.read()

# Write the retrieved bytes to a separate file, leaving the originally
# saved weights file untouched.
f_out_path = 'files_1/model_retrieved.h5'
with open(f_out_path, 'wb') as f:
    f.write(weights_temp)


# Load the weights retrieved from MongoDB back into the model.
model.load_weights(f_out_path)

# Run the model on the first 5 test images.
first_five = test_images[:5]
predictions = model.predict(first_five)

# Pick the highest-scoring class per image and show it.
predicted_digits = predictions.argmax(axis=1)
print(predicted_digits)  # [7, 2, 1, 0, 4]

# Show the matching ground-truth labels for comparison.
true_digits = test_labels[:5]
print(true_digits)  # [7, 2, 1, 0, 4]


# Ref: https://victorzhou.com/blog/keras-neural-network-tutorial/

No comments:

Post a Comment