画像系機械学習プロダクトで手元のスクリプトとデプロイ環境の精度が違うときのチェックリスト
はじめに
みなさんこんにちは、いかがお過ごしでしょうか。
機械学習もだいぶ広まってきて、右も左も無限に強い人ばかりで、はやく3億円稼いですべてから解放されて虹コン大好きクラブとして生きていきたい今日この頃です。
機械学習がだいぶ広まってくると、機械学習モデルをデプロイする人も増えてきていると思います。
しかし、機械学習モデルのデプロイは非常に複雑です。
モデル、依存ファイル、前処理/後処理、ライブラリ、その他を実験スクリプトと揃えながら、リクエストに対してレスポンスを返すサーバーを継続して稼働させる必要があるためです。
さらには複数人で開発する場合、コードすべてを把握することが難しい、といった困難も付随します。
すると当然、「手元で出ていた精度と、デプロイ環境にリクエストを投げたときの精度が違う」といったバグが生じます。
本日は、そんなときにバグ原因を見つける方針を記事にしたいと思います。
なぜなら、今日まさにこのバグ発見作業で半日溶かしたからです。
それなりに機械学習と触れ合ってきた私でも苦しんだので、チェックリスト的なものがあると便利かと思いました。
もっとよい方法があったり間違ったことを言っていたらコメントで教えてください。
状況の前提
- 機械学習モデルを作ったり、その精度を測るスクリプトが手元にある
- 対象のモデルは画像系で、Deep Learningを用いている。
- 上記スクリプトを用いて作ったモデルが内蔵された、HTTPリクエストなどを受け付けて、推論結果をレスポンスで返すサーバがある
- その2つに同じ画像を投げているのに、返ってくる結果が違うので困っている
チェックリスト
そもそも本当に同じ状態で精度計測しているかの確認
モデルが同一であるか確認する
モデルが違うと、当然出力も異なります。そのため、まずはここから疑うのがよいと考えています。
デプロイ環境で参照しているモデルを手元に落としてきて確認するのが最速かと思います。
ここで、onnxやtensorflowのsaved modelのような、ネットワーク定義とモデルパラメータがセットで保存されるような形式を使用しているとかなりデバッグがはやくなります。
入力画像が同一であるか確認する
全然違う画像を推論させていると当然全然違う結果になりますが、この辺を疑うことから始めておくと悲しみが防げるかと思います。
画像が同一であるかを簡易的にチェックするためには、hashlibを使うことが多いです。適当にprintして、一致しているかどうかを目視確認、くらいでいいかと思います。
例えばこんな感じで。
from PIL import Image import numpy as np import hashlib def get_hash(image_path): image_array = np.asarray(Image.open(image_path)) hash = hashlib.sha256(image_array.tostring()).hexdigest() return hash
依存ファイルが同一であるか確認する
「依存ファイル」の表現は広いのですが、例えばクラスのindexとクラス名のマッピングを記述したファイルなどが挙げられます。
(index0が犬でindex1が猫で、、のようなファイル)
このあたりもチェックするとよいと考えています。
そもそも推論に再現性があるのかの確認
手元の推論を複数回行ってみる
たまに盲点になるのですが、何らかのランダム要素が入って再現性が無いケースがたまにあります。
モデルに入るまでの確認
HTTPリクエストのときに圧縮したりしていないか
JPEGは非可逆圧縮のため、リクエストにJPEG形式を用いていると画素値が微妙に異なってしまうことがあります。
これもhashで確認してみましょう。
前処理が同一であるか
通常、生画像をそのままネットワークに入力することはありません。
- 正方形になるようにリサイズ / padding
- 255で割ったり平均を引いたりして正規化
のような操作をした後、ネットワークに入力します。
この処理に違いがあると当然異なる結果になるので、確認します。
処理を追うのも良いですが、最も簡単なのはネットワークに入る直前の画像のhashを比較することです。
hashが違ったら、前処理のどこが違うのかをチェックします。観点として抜けやすいのは以下辺りでしょうか。
- 処理の順番
- リサイズのアルゴリズム
モデルを出てからの確認
モデルの出力が同一であるか
- 出力のshape
- 出力のlogit値
などをプリントして確認します。
後処理が同一であるか
前処理同様に後処理についてもチェックします。
モデル内部の挙動の確認
上記に従って後処理まで追うと、
- バグ原因がわかった!
- モデル内部に謎のバグ原因がある...
のどちらかになるのではないかと思います。 モデル内部をもうちょっと頑張って見てみましょう。
trainingモードで推論していないかチェック
手元かデプロイ先のどちらかで、推論時にtrainingモードを用いていないかチェックしてみて下さい。 特に推論用にexportされた形式を取っていない場合に起きやすいかと思います。 pytorchだと、trainingモードかevalモードかをmodel.eval()などで設定します。
具体的には、trainingモードとevalモードだと一般に以下のような違いがあります。(フレームワーク毎に色々あるかもしれないので適宜調査してみて下さい。)
- batch normalization: training時はミニバッチの平均や分散を用い、eval時はtraining時に計算された移動平均の値を用いる
- dropout: training時は確率的にdropoutし、eval時はしない
バッチサイズを変えて出力をチェック
↑のbatch normalizationもそうですが、バッチサイズによって出力が異なるかどうか確認するのもよいかと思います。
本日の私のケースは少々特殊なケースであったため、この推論時のバッチサイズが手元とデプロイ先で異なることに起因して、片方で潜在的なバグが発動していたケースでした。
ライブラリやデバイス等のチェック
ライブラリのバージョンやデバイス(CPU/GPUなど)による差異を確認することで解決するケースもあるかと思います。
上記をやって、それでもだめだったら
ごめんなさい。遭遇/解決したらコメントで教えて頂けると嬉しいです。
おわりに
以上です。色々な機械学習プロダクトをPoCで終わらせずに運用まで持っていって、こういう知見をつけまくって、実務グランドマスターになりたいですね。それでは。