作りながら学ぶ深層強化学習 p.116のエラー

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'other'

→ ランタイムエラー:引数otherにスカラーのlong型オブジェクトを期待したが、Float型が代入された
Pytorchでは、テストデータにlong型を使用すること
そのため、P.112のDataLoaderの作成部分のコード(↓)のTensorに変換すると同時にlong型に変換を行う。

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=1/7, random_state=0)

X_train = torch.Tensor(X_train)
X_test = torch.Tensor(X_test)
y_train = torch.Tensor(y_train)
y_test = torch.Tensor(y_test)

下2行を次のように変える。

y_train = torch.Tensor(y_train).long()
y_test = torch.Tensor(y_test).long()

うまく、学習ができた。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です