前端AI之路: KerasJS初探
推薦先下載項目,直接運行起來看看效果:
地址:
簡介
Keras是一款非常流行的深度學習模型開發框架,基於python,語法簡潔,封裝程度高,只需十幾行代碼就可以構建一個深度神經網絡。
Keras.js是一個可以在瀏覽器中運行深度神經網絡的JS框架,支持CPU,GPU計算。區別於Keras,Keras.js只能運行已經調試好的模型,無法進行模型訓練。
KerasJS開發流程如下,首先使用Keras開發訓練神經網絡,將神經網絡模型和數據導出為文件,KerasJS在瀏覽器端加載此文件,這樣才能進行預測。
模型
借鑒這篇文章
,開發一個識別聖誕老人的神經網絡。本文不涉及Keras的開發細節,感興趣的同學可以去原文查看。這裡直接給出python代碼
def build_model(): model = models.Sequential() model.add(layers.Conv2D(20,(5,5),activation='relu',input_shape=(128,128,3))) model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2))) model.add(layers.Conv2D(50,(5,5),activation='relu',padding='same')) model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2))) model.add(layers.Flatten()) model.add(layers.Dense(500,activation='relu')) model.add(layers.Dense(1, activation='sigmoid')) model.compile(optimizer=optimizers.RMSprop(lr=2e-5), loss='binary_crossentropy', metrics=['acc']) return model
數據
標註數據是AI模型的原料,數據蒐集特別是圖片蒐集是前端可以介入的一個環節。筆者基於React,開發了一款Chrome圖片批量下載插件GetThemAll,方便我們進行標記圖片蒐集。
安裝好插件後,去谷歌圖片搜索“santa”, 使用插件標記不需要的圖片,然後下載到本地的santa文件夾,通過谷歌圖片可以蒐集到400張聖誕老人的圖片。
接著我們再下載一些非聖誕老人的圖片,搜索“object”,同樣的使用GetThemAll插件下載大約400張圖片到本地的non_santa文件夾中。
除了訓練數據集,我們還需要一個測試數據集用來衡量模型的泛化能力。在本地新建一個test文件夾,把剛剛準備好的訓練集裡面的最後100張聖誕老人圖片移到test文件夾下的santa文件中,同樣的,移動100張非聖誕老人圖片到non_stanta文件中。這樣,你可以得到如下的本地圖片集:
有了標記數據,我們就可以進行模型訓練啦。具體的訓練過程請見pyton代碼,這裡直接給出訓練的結果,藍點表示訓練數據集準確率,藍線表示測試數據集準備率,模型有著明顯的High Variance問題,不過這個bug留給深度學習的專家們解決吧,這裡就假設這個模型可用。
遷移
上一步訓練出的模型keras_santa.h5(h5是文件後綴,和HTML5沒啥關係)不能直接給KerasJS使用,需要通過KerasJS提供的轉換工具轉換後,方可被KerasJS加載解析。
./encoder.py keras_santa.h5
轉換後,得到了keras_santa.bin文件,20M左右,這個文件包含了神經網絡結構和所有參數,可以被KerasJS加載。
KerasJS
通過上面的步驟,我們得到了一個訓練完成的CNN神經網絡以及全部參數,這個網絡結構和參數全部保存在keras_santa.bin文件中。接下來,我們只需要在瀏覽器中復原上面的神經網絡,然後就可以開始做預測啦。
使用webpack配合React,搭建一套簡單的開發環境。做好了基礎工作,就可以開始第一步開發,加載神經網絡模型文件keras_santa.bin:
const model = new KerasJS.Model({ filepath: 'http://localhost:3000/keras_santa.bin', gpu: false }) //KerasJS提供模型加载进度接口,考虑到模型文件体积非常大,这个接口会经常用到model.events.on('loadingProgress', (progress) => { this.setState({ loadingtitle: '模型加载', progress: parseInt(progress) }) })
使用上面的模型做預測前,需要將數據轉化成模型能夠接受的數據格式。這個聖誕老人網絡需要數據輸入格式為(128,128,3),也即是圖片需要為128x128分辨率,只能包含RGB三個分量。
借助canvas,可以實現圖片分辨率轉換:
_updateImageSrc(imgid) { const ctx = this.refs.canvas.getContext('2d'); const imgdom = document.createElement('img'); imgdom.src = `http://localhost:3000/${imgid}.jpeg` this.setState({ prediction:0 }) imgdom.onload = ()=>{ ctx.drawImage(imgdom,0,0,128,128) const imagedata = ctx.getImageData(0,0,128,128) const processeddata = ImageDataUtils.preprocess(imagedata) setTimeout(()=>{ this.doPrediction(processeddata) },100); } }
注意preprocess方法,通過canvas獲取到的圖片資源包含了rgba四個維度,prepross返回這4個維度中的前3個維度,也即rgb,同時將數據標準化:
export default class ImageDataUtils { static preprocess(imageData) { const { width, height, data } = imageData; const dataTensor = ndarray(new Float32Array(data),[width,height,4]) const dataProcessedTensor = ndarray(new Float32Array(width*height*3),[width,height,3]) //从[0,255]转化到[0,1] ops.divseq(dataTensor,255) //获取R数据ops.assign(dataProcessedTensor.pick(null,null,0),dataTensor.pick(null,null,0)) //获取G数据ops.assign(dataProcessedTensor.pick(null,null,1),dataTensor.pick(null,null,1)) //获取B数据ops.assign(dataProcessedTensor.pick(null,null,2),dataTensor.pick(null,null,2)) const preprocessedData = dataProcessedTensor.data return preprocessedData } }
最後,使用上面返回的數據做預測
async doPrediction(imagedata) { if(!this.model) return; const inputname = this.model.inputLayerNames[0] const inputdata = {[inputname]: imagedata} const prediction = await this.model.predict(inputdata) this.setState({ prediction: prediction.output[0] }) }

思考
可以看到,KerasJS在預測過程中,整個頁面無法響應用戶操作。這是因為神經網絡計算過程中佔用了大量CPU資源,從而致使頁面卡頓。下一篇文章中,我們將介紹如何使用WebGL,將計算過程轉移到GPU,達到實現前端高性能計算的目的。
相關資源
- Image classification with Keras and deep learning ,Adrain Rosebrock
- GetThemAll , eeandrew
- React KerasJS , eeandrew