訂閱
糾錯
加入自媒體

更復雜的體系結構能保證更好的模型嗎?

使用的數據集和數據預處理

我們將使用Kaggle的狗與貓數據集。它是根據知識共享許可證授權的,這意味著你可以免費使用它:

該數據集相當大——25000張圖像均勻分布在不同的類中(12500張狗圖像和12500張貓圖像)。它應該足夠大,以訓練一個像樣的圖像分類器。

你還應該刪除train/cat/666.jpg和train/dog/11702.jpg圖像,這些已經損壞,你的模型將無法使用它們進行訓練。

接下來,讓我們看看如何使用TensorFlow加載圖像。

如何使用TensorFlow加載圖像數據

今天你將看到的模型將比前幾篇文章中的模型具有更多的層。

為了可讀性,我們將從TensorFlow中導入單個類。如果你正在跟進,請確保有一個帶有GPU的系統(tǒng),或者至少使用Google Colab。

讓我們把庫的導入放在一邊:

image.png

這是很多,但模型會因此看起來格外干凈。

我們現在將像往常一樣加載圖像數據——使用ImageDataGenerator類。

我們將把圖像矩陣轉換為0–1范圍,使用用三個顏色通道,將所有圖像調整為224x224。出于內存方面的考慮,我們將barch大小降低到32:

image.png

以下是你應該看到的輸出:

讓我們鼓搗第一個模型!

向TensorFlow模型中添加層會有什么不同嗎?

從頭開始編寫卷積模型總是一項棘手的任務。網格搜索最優(yōu)架構是不可行的,因為卷積模型需要很長時間來訓練,而且有太多的參數需要檢查。實際上,你更有可能使用遷移學習。這是我們將在不久的將來探討的主題。

今天,這一切都是為了理解為什么在模型架構上大刀闊斧是不值得的。我們用一個簡單的模型獲得了75%的準確率,所以這是我們必須超越的基線。

模型1-兩個卷積塊

我們將宣布第一個模型在某種程度上類似于VGG體系結構——兩個卷積層,后面是一個池層。濾波器設置如下,第一個塊32個,第二個塊64個。

至于損失和優(yōu)化器,我們將堅持基本原則——分類交叉熵和Adam。數據集中的類是完全平衡的,這意味著我們只需跟蹤準確率即可:

model_1 = tf.keras.Sequential([

   Conv2D(filters=32, kernel_size=(3, 3), input_shape=(224, 224, 3), activation='relu'),

   Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Flatten(),

   Dense(units=128, activation='relu'),

   Dense(units=2, activation='softmax')

])

model_1.compile(

   loss=categorical_crossentropy,

   optimizer=Adam(),

   metrics=[BinaryAccuracy(name='accuracy')]

model_1_history = model_1.fit(

   train_data,

   validation_data=valid_data,

   epochs=10

以下是經過10個epoch后的訓練結果:

看起來我們的表現并沒有超過基線,因為驗證準確率仍然在75%左右。如果我們再加上一個卷積塊會發(fā)生什么?

模型2-三個卷積塊

我們將保持模型體系結構相同,唯一的區(qū)別是增加了一個包含128個濾波器的卷積塊:

model_2 = Sequential([

   Conv2D(filters=32, kernel_size=(3, 3), input_shape=(224, 224, 3), activation='relu'),

   Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=128, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=128, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Flatten(),

   Dense(units=128, activation='relu'),

   Dense(units=2, activation='softmax')

])

model_2.compile(

   loss=categorical_crossentropy,

   optimizer=Adam(),

   metrics=[BinaryAccuracy(name='accuracy')]

model_2_history = model_2.fit(

   train_data,

   validation_data=valid_data,

   epochs=10

日志如下:

效果變差了。雖然你可以隨意調整batch大小和學習率,但效果可能仍然不行。第一個架構在我們的數據集上工作得更好,所以讓我們試著繼續(xù)調整一下。

模型3-帶Dropout的卷積塊

第三個模型的架構與第一個模型相同,唯一的區(qū)別是增加了一個全連接層和一個Dropout層。讓我們看看這是否會有所不同:

model_3 = tf.keras.Sequential([

   Conv2D(filters=32, kernel_size=(3, 3), input_shape=(224, 224, 3), activation='relu'),

   Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Flatten(),

   Dense(units=512, activation='relu'),

   Dropout(rate=0.3),

   Dense(units=128),

   Dense(units=2, activation='softmax')

])

model_3.compile(

   loss=categorical_crossentropy,

   optimizer=Adam(),

   metrics=[BinaryAccuracy(name='accuracy')]

model_3_history = model_3.fit(

   train_data,

   validation_data=valid_data,

   epochs=10

以下是訓練日志:

太可怕了,現在還不到70%!上一篇文章中的簡單架構非常好。反而是數據質量問題限制了模型的預測能力。

結論

這就證明了,更復雜的模型體系結構并不一定會產生性能更好的模型。也許你可以找到一個更適合貓狗數據集的架構,但這可能是徒勞的。

你應該將重點轉移到提高數據集質量上。當然,有20K個訓練圖像,但我們仍然可以增加多樣性。這就是數據增強的用武之地。

感謝閱讀!

       原文標題 : 更復雜的體系結構能保證更好的模型嗎?

聲明: 本文由入駐維科號的作者撰寫,觀點僅代表作者本人,不代表OFweek立場。如有侵權或其他問題,請聯系舉報。

發(fā)表評論

0條評論,0人參與

請輸入評論內容...

請輸入評論/評論長度6~500個字

您提交的評論過于頻繁,請輸入驗證碼繼續(xù)

暫無評論

暫無評論

    掃碼關注公眾號
    OFweek人工智能網
    獲取更多精彩內容
    文章糾錯
    x
    *文字標題:
    *糾錯內容:
    聯系郵箱:
    *驗 證 碼:

    粵公網安備 44030502002758號