LIUPENG BLOG
Liupeng
Feb 24, 2020
This article takes about 3 minutes to read.

电影评论分类

代码

import numpy
import tensorflow
import tensorflow_hub
import tensorflow_datasets

# Load the IMDB reviews dataset as (review_text, label) pairs (as_supervised=True).
# The original "train" split is divided 60% / 40% into training and validation
# sets; the "test" split is kept whole for the final evaluation.
# NOTE: `Split.TRAIN.subsplit([6, 4])` is deprecated and was removed from
# tensorflow_datasets; the percent-slicing split strings below are the supported
# replacement and produce the same 60/40 division.
train_data, validation_data, test_data = tensorflow_datasets.load(
    name='imdb_reviews',
    split=('train[:60%]', 'train[60%:]', 'test'),
    as_supervised=True)


# Pull a single batch of 10 (review, label) pairs from the training data so the
# raw inputs can be inspected before training.
train_examples_batch, train_labels_batch = next(
    iter(train_data.batch(batch_size=10))
)

# Uncomment to inspect the sample reviews and their labels:
# print(train_examples_batch)
# print(train_labels_batch)

# Binary classification.
#
# Text embedding: a pre-trained TF Hub text-embedding module (gnews-swivel-20dim,
# presumably 20-dimensional per its name), loaded from the hub.tensorflow.google.cn
# mirror. input_shape=[] together with dtype=string means each example fed to the
# layer is a single scalar string (one raw review). trainable=True lets the
# embedding weights be updated during training.
embedding = 'https://hub.tensorflow.google.cn/google/tf2-preview/gnews-swivel-20dim/1'
hub_layer = tensorflow_hub.KerasLayer(
    handle=embedding,
    dtype=tensorflow.string,
    input_shape=[],
    trainable=True,
)

# Uncomment to see the embeddings produced for the first three sample reviews:
# print(hub_layer(train_examples_batch[:3]))


# Classifier stacked on the text-embedding layer:
#   hub_layer -> Dense(16, relu) -> Dense(1, sigmoid)
# The final sigmoid unit yields one score in (0, 1) per review, which pairs with
# the binary cross-entropy loss configured below.
model = tensorflow.keras.Sequential([
    hub_layer,
    tensorflow.keras.layers.Dense(units=16, activation='relu'),
    tensorflow.keras.layers.Dense(units=1, activation='sigmoid'),
])

model.summary()

model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Train for 5 epochs. Training examples are shuffled through a 10,000-element
# buffer before batching; training and validation data both use batches of 512.
# The returned History object is kept in `history`.
history = model.fit(
    train_data.shuffle(buffer_size=10000).batch(batch_size=512),
    validation_data=validation_data.batch(batch_size=512),
    epochs=5,
    verbose=1,
)

# Evaluate the trained model on the held-out test split (batches of 512) and
# print each metric name alongside its value.
results = model.evaluate(test_data.batch(512), verbose=2)
for name, value in zip(model.metrics_names, results):
    # Bug fix: format the metric `value`, not the `validation_data` Dataset.
    # Formatting a Dataset with %.3f raised TypeError, so nothing was printed.
    print("%s: %.3f" % (name, value))