Python 3 基礎からDeep Learning まで

Python 3の勉強を始めました! 気になる箇所,つまづいて復習しなおしたところなどまとめていきます! 質問・ご意見等,大歓迎です!

Python 3:3次元グラフの書き方(matplotlib, pyplot, mplot3d, MPL)

3次元のグラフの書き方

ニューラルネットワークの学習で現れる,偏微分(勾配)の理解のため3次元グラフを描いてみる.(参考:数値微分
Pythonで,NumPyとmatplotlibを使って3次元グラフを描く

準備

  • 3次元なのでmpl_toolkits.mplot3dなどをインポート
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
  • 引数の2乗和を計算する関数を例として考える
def func1(x, y):
    return x**2 + y**2
  • 描写データの作成
    • 3次元で描写するには2次元メッシュが必要
    • 2次元配列をarangeを用いて作る
    • x, y をそれぞれ1次元領域で分割する
x = np.arange(-3.0, 3.0, 0.1)
y = np.arange(-3.0, 3.0, 0.1)
  • 2次元メッシュはmeshgridでつくる
    • Xの行にxの行列を,Yは列にyの配列を入れたものになっている
X, Y = np.meshgrid(x, y)
Z = func1(X, Y)
  • グラフの作成
    • figureで2次元の図を生成する
    • その後,Axes3D関数で3次元にする
fig = plt.figure()
ax = Axes3D(fig)
  • 軸ラベルの設定
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("f(x, y)")

-グラフ描写

ax.plot_wireframe(X, Y, Z)
plt.show()

以上をまとめて実装すれば,以下のような3次元のグラフが描画されます.


ex1.png


plotのプロパティとして,color(線の色),linestyle(線の種類),マーカーの種類などが変更できる.

もっと簡単に描けるコードなどありましたらご教授ください.
詳細は後ほど追加する予定です.