
こんにちは、四谷ラボのやましんです。
晴れて結婚20周年を迎えることができました!!
こうやって、書きたいときにブログを書けるのも、周りの人が支えてくれるからだね。
ありがとう!
結果発表
さっそく「GAFA。水面下の争いに決着。」を目撃すべく、この映像を見て欲しい。
後半の追い上げがハンパなく、GAFAの中で優勝を獲得したのはAppleでした。
※仮想環境で文字列の泳ぎを学習させて、競わせたものであり、各社の業績とは無関係です。
今回挑戦した内容
バーチャル空間でGoogle、Apple、Facebookそしてamazonの文字列をくねらせて、水泳の機械学習に挑戦しました。
機械学習の環境
Unity バージョン2020.1.10f1
ml-agents version 1.1.0
https://github.com/Unity-Technologies/ml-agents
文字列を泳がせて学習するアイデア
何はともあれ、文字列を構成する文字を3次元化する。

こんな感じ。これは「よつやらぼ」だよ。わかったよね。
文字を3次元化できるアセットがあったので、これを利用しました。
https://assetstore.unity.com/packages/tools/particles-effects/flyingtext3d-3627
直接関係ないけど、最近3Dプリントが楽しくて、上の文字列を具現化してみた。

次に、3次元化した文字列に浮力を与える。
めちゃくちゃいいのがあった。浮力のシミュレーションがとてもうまく再現できる。

https://github.com/dbrizov/NaughtyWaterBuoyancy
さらに、隣り合う文字同士を関節のように接続する。
Unity configurable jointを使って、隣り合う文字同士を接続することにした。
上から見ると、各間接の可動域が一律で180°であることがわかるね。
それと、文字同士の衝突判定はしないよ。
最後は、こんな感じにできました。
よし!他の文字列も同じ要領で作るぞ。
実装
上記で作った文字列のように、文字間の可動域を一定(最大180°)にした状態で、各関節をどのように運動させれば、大きな推進力が得られるかを機械学習(PPO)で推定する。
機械学習の報酬とペナルティ。
ポイントは、報酬とペナルティ。
試行錯誤した結果が次のもの。
- Time penalty
時間が経過すると共にマイナス(ペナルティ)
- Velocity Rewards
文字列のゴール(プール左端)に対する相対速度が速ければ報酬。
- Direction penalty
文字列がゴールに向いていれば報酬。
- Course out penalty
文字列がレーンからずれたらペナルティ
- Reached target
文字列がゴール到達すれば報酬
- Fell off platform
文字列がプラットフォームから落ちるとペナルティ
プログラム
このプログラムは、最も重要な ml-agents の Agentクラスを継承したクラスです。 今回、有料アセット(文字列3次元化、小学校プール)を使用しているので、githubには共有しません。 (Unity有料アセットを含むプログラムのオープンソース化(バイナリ化?)について、知っている人教えて下さい。)
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using FlyingText3D;
using NaughtyWaterBuoyancy;
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine.UI;
public class SwimmingAgent : Agent
{
[Header("Specific to SwimmingAgent")]
private List<string> colors = new List<string>{"white", "black", "red","blue","yellow"};
public string SwimingTextString;
public string SwimingTextColor;
public int course_num;
private List<ConfigurableJoint> joint_list = new List<ConfigurableJoint>();
private List<Vector3> current_normalized_rot = new List<Vector3>();
private List<GameObject> Text3Ds = new List<GameObject>();
private Rigidbody initialCharacter = null;
public float maxJointSpring;
public float jointDampen;
public float maxJointForceLimit;
public rerowrds total;
public Text totalText;
public Text totalTotalText;
public rerowrds time;
public Text timeText;
public Text timeTotalText;
public rerowrds velocity;
public Text velocityText;
public Text velocityTotalText;
public rerowrds direction;
public Text directionText;
public Text directionTotalText;
public rerowrds touch_target;
public Text touch_targetText;
public Text touch_targetTotalText;
public rerowrds course_out;
public Text course_outText;
public Text course_outTotalText;
Vector3 m_DirToTarget;
//private static FlyingText instance;Vector3 m_DirToTarget
Quaternion m_LookRotation;
Matrix4x4 m_TargetDirMatrix;
//private int color_index = 0;
GameObject CreateFlyingText(String text, float pos)
{
FlyingText.colliderType = ColliderType.ConvexMesh;
FlyingText.defaultSize = 0.8f;
FlyingText.defaultDepth = 0.1f;
GameObject objectParent = FlyingText.GetObjects("<color=" + SwimingTextColor +">" + text) ;
// GameObject objectParent2 = FlyingText.GetObjects(text);
GameObject charPolygon = objectParent.transform.GetChild(0).gameObject;
FloatingObject fo = charPolygon.AddComponent<FloatingObject>() as FloatingObject;
fo.dragInWater = 0.0001f;
fo.angularDragInWater = 0.0001f;
//fo.density = 0.75f;
//Vector3 org_pos = new Vector3(9.5f, 1.0f, 0.0f);
Vector3 org_pos = new Vector3(0.0f, 1.0f, 0.0f);
//objectParent.transform.rotation = new Quaternion(90.0f, -90.0f, -90.0f, 0.0f);
objectParent.transform.position = new Vector3(org_pos.x + pos, org_pos.y, org_pos.z);
return charPolygon;
}
public override void Initialize()
{
CreateText3Ds();
SetResetParameters();
}
public override void CollectObservations(VectorSensor sensor)
{
for (int i = 0; i < Text3Ds.Count(); i++)
{
GameObject charPolygon = Text3Ds[i];
/*
sensor.AddObservation(charPolygon.transform.rotation);
sensor.AddObservation(charPolygon.transform.position);
*/
Rigidbody rb = charPolygon.GetComponent<Rigidbody>();
sensor.AddObservation(rb.velocity);
sensor.AddObservation(rb.transform.position);
sensor.AddObservation(rb.transform.rotation);
sensor.AddObservation(rb.centerOfMass);
}
for(int i = 0;i < joint_list.Count();i++)
{
sensor.AddObservation(current_normalized_rot[i]);
}
}
// public float speed = 10;
Vector3 PrevLocation;
public override void OnActionReceived(float[] vectorAction)
{
int idx = 0;
for(int i = 0;i < joint_list.Count();i++)
{
SetJointTargetRotation(i,0.0f, vectorAction[idx++], 0.0f);
//SetJointStrength(joint_list[i],vectorAction[idx++]);
SetJointStrength(joint_list[i],1.0f);
}
for (int i = 0; i < Text3Ds.Count(); i++)
{
Rigidbody rb = Text3Ds[i].GetComponent<Rigidbody>();
rb.centerOfMass = new Vector3(0f, (vectorAction[idx++]-0.5f), 0f);
}
CalcRewards();
}
public struct rerowrds
{
public Boolean is_use;
public float total;
public float current_value;
public float rate;
public Boolean is_penalty;
};
void SetScores()
{
if (timeText)
{
timeText.text = time.current_value.ToString();
timeTotalText.text = time.total.ToString();
velocityText.text = velocity.current_value.ToString();
velocityTotalText.text = velocity.total.ToString();
directionText.text = direction.current_value.ToString();
directionTotalText.text = direction.total.ToString();
course_outText.text = course_out.current_value.ToString();
course_outTotalText.text = course_out.total.ToString();
touch_targetText.text = touch_target.current_value.ToString();
touch_targetTotalText.text = touch_target.total.ToString();
totalText.text = total.current_value.ToString();
totalTotalText.text = total.total.ToString();
}
}
void CalcRewards()
{
//Time penalty
if(time.is_use)
{
time.current_value = time.rate * (time.is_penalty ? -1:1);
AddReward(time.current_value);
time.total += time.current_value;
}
//Velocity Rewards
if (velocity.is_use)
{
velocity.current_value = -initialCharacter.velocity.x * velocity.rate * (velocity.is_penalty ? -1:1);
AddReward(velocity.current_value);
velocity.total += velocity.current_value;
}
//Direction penalty
if(direction.is_use)
{
direction.current_value = (.5f - Math.Abs(initialCharacter.transform.rotation.y)) * direction.rate * (direction.is_penalty ? -1:1);
AddReward(direction.current_value);
direction.total += direction.current_value;
}
//Course out penalty
if(course_out.is_use)
{
course_out.current_value = Math.Abs(initialCharacter.transform.position.z+(5.0f - course_num * 2.5f)) * course_out.rate * (course_out.is_penalty ? -1:1);
AddReward(course_out.current_value);
course_out.total += course_out.current_value;
/*
course_out.current_value = Math.Abs(initialCharacter.velocity.z) * course_out.rate * (course_out.is_penalty ? -1:1);
AddReward(course_out.current_value);
course_out.total += course_out.current_value;
*/
}
// Reached target
if(touch_target.is_use)
{
float distanceToTarget = initialCharacter.transform.position.x - Target.position.x;
if (distanceToTarget < 2f)
{
touch_target.current_value = touch_target.rate * (touch_target.is_penalty ? -1:1);
AddReward(touch_target.current_value);
touch_target.total += touch_target.current_value;
EndEpisode();
}
}
touch_target.current_value = 0;
// Fell off platform
if (initialCharacter.transform.position.y < -20)
{
EndEpisode();
}
total.current_value = time.total + velocity.total + direction.total + course_out.total + touch_target.total;
total.total += total.current_value;
SetScores();
}
public Transform Target;
public override void OnEpisodeBegin()
{
SetResetParameters();
}
public override void Heuristic(float[] actionsOut)
{
actionsOut[0] = -Input.GetAxis("Horizontal");
actionsOut[1] = Input.GetAxis("Vertical");
}
//int a=0;
private void ResetText3Ds()
{
float interval = 0.0f;
float current_pos = 0.0f;
SetConfigrableJoint();
for (int i = 0; i < Text3Ds.Count(); i++)
{
GameObject charPolygon = Text3Ds[i];
//Vector3 org_pos = new Vector3(9.5f, -9.0f, - 5.0f + course_num * 2.5f );
Vector3 org_pos = new Vector3(0.0f, -9.0f, - 5.0f + course_num * 2.5f );
charPolygon.transform.position = new Vector3(org_pos.x + current_pos, org_pos.y, org_pos.z);
charPolygon.transform.rotation = new Quaternion(0.0f,0.0f,0.0f,0.0f);
MeshCollider mc = charPolygon.GetComponent<MeshCollider>();
//mc.isTrigger = true;
float width = mc.bounds.size.x;
current_pos += width + interval;
}
}
private void SetConfigrableJoint()
{
for (int i = 0; i < joint_list.Count(); i++)
{
joint_list[i].targetRotation = new Quaternion(0.0f, 0.0f, 0.0f, 0.0f);
current_normalized_rot[i] = new Vector3(0.0f,0.0f,0.0f);
}
}
GameObject pre_charPolygon = null;
private void CreateConfigrableJoint(GameObject charPolygon)
{
if (pre_charPolygon)
{
float height = charPolygon.GetComponent<MeshCollider>().bounds.size.y;
float depth = charPolygon.GetComponent<MeshCollider>().bounds.size.z;
ConfigurableJoint cj = charPolygon.AddComponent<ConfigurableJoint>() as ConfigurableJoint;
joint_list.Add(cj);
current_normalized_rot.Add( new Vector3(0.0f, 0.0f, 0.0f));
cj.xMotion = ConfigurableJointMotion.Locked;
cj.yMotion = ConfigurableJointMotion.Locked;
cj.zMotion = ConfigurableJointMotion.Locked;
cj.angularXMotion = ConfigurableJointMotion.Locked;
cj.angularYMotion = ConfigurableJointMotion.Limited;
cj.angularYLimit = new SoftJointLimit() { limit = 90.0f };
cj.angularZMotion = ConfigurableJointMotion.Locked;
/**
cj.angularZMotion = ConfigurableJointMotion.Limited;
cj.angularZLimit = new SoftJointLimit() { limit = 90.0f };
*/
cj.anchor = new Vector3(0.0f, height / 2.0f, depth/2.0f);
cj.axis = new Vector3(0.0f, 0.0f, -1.0f);
cj.massScale = 1000;
cj.connectedMassScale = 1000;
cj.enableCollision = false;
//ConfigurableJoint joint = charPolygon.GetComponent<ConfigurableJoint>() as ConfigurableJoint;
Rigidbody rb = pre_charPolygon.GetComponent<Rigidbody>();
cj.connectedBody = rb;
}
pre_charPolygon = charPolygon;
}
private void CreateText3Ds()
{
float interval = 0.0f;
float current_pos = 0.0f;
char[] c = SwimingTextString.ToCharArray();
//Rigidbody pre_Rigidbody = null;
for (int i = 0; i < c.Length; i++)
{
string s = new String(new char[] { c[i] });
GameObject charPolygon = CreateFlyingText(s, current_pos);
// color_index ++;
charPolygon.layer = 8;
Text3Ds.Add(charPolygon);
MeshCollider mc = charPolygon.GetComponent<MeshCollider>();
float width = mc.bounds.size.x;
current_pos += width + interval;
CreateConfigrableJoint(charPolygon);
Rigidbody rb = charPolygon.GetComponent<Rigidbody>();
rb.centerOfMass = new Vector3(0f, -0.5f, 0f);
if(!initialCharacter)
{
initialCharacter = rb;
//initialCharacter.isKinematic = true;
}
}
}
public void ResetRewords()
{
total.total = 0.0f;
total.current_value = 0.0f;
total.is_penalty = false;
total.rate = 1.0f;
touch_target.is_use =true;
touch_target.total = 0.0f;
touch_target.current_value = 0.0f;
touch_target.is_penalty = false;
touch_target.rate = 1000.0f;
velocity.is_use =true;
velocity.total = 0.0f;
velocity.current_value = 0.0f;
velocity.is_penalty = false;
velocity.rate = 1.0f;
direction.is_use =true;
direction.total = 0.0f;
direction.current_value = 0.0f;
direction.is_penalty = false;
direction.rate = 1.0f;
time.is_use =true;
time.total = 0.0f;
time.current_value = 0.0f;
time.is_penalty = true;
time.rate = 1.0f;
course_out.is_use =true;
course_out.total = 0.0f;
course_out.current_value = 0.0f;
course_out.is_penalty = true;
course_out.rate = .1f;
}
public void SetResetParameters()
{
ResetRewords();
ResetText3Ds();
}
public void SetJointTargetRotation(int i, float x, float y, float z)
{
ConfigurableJoint joint = joint_list[i];
x = (x + 1f) * 0.5f;
y = (y + 1f) * 0.5f;
z = (z + 1f) * 0.5f;
var xRot = Mathf.Lerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, x);
var yRot = Mathf.Lerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, y);
var zRot = Mathf.Lerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, z);
current_normalized_rot[i] = new Vector3(Mathf.InverseLerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, xRot),
Mathf.InverseLerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, yRot),
Mathf.InverseLerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, zRot));
joint.targetRotation = Quaternion.Euler(xRot, yRot, zRot);
}
public void SetJointStrength(ConfigurableJoint joint, float strength)
{
var rawVal = (strength + 1f) * 0.5f * maxJointForceLimit;
var jd = new JointDrive
{
positionSpring = maxJointSpring,
positionDamper = jointDampen,
maximumForce = rawVal
};
joint.slerpDrive = jd;
//currentStrength = jd.maximumForce;
}
}
機械学習時間の文字列の泳ぎの変化
GAFA水泳トレーニング開始直後
みんな、がんばれー。終わったら、はちみつレモンあるわよ。
機械学習2時間経過
ゴールは左側なんだけどなぁ。
機械学習5時間経過
全体的にゴールのある左側へ向かっているのがわかる。
ちなみに、オリンピック水泳選手は、小学生でも1日4,5時間泳ぐこともあるらしい。その辺にいる両生類より水に長く入ってるんちゃう?いい意味で。
機械学習42時間経過
まっすぐ、泳げる個体が増えてきた。
機械学習72時間経過
42時間と比較しても、速度が上がっているのがわかる。
機械学習状況のグラフ(TensorBoard)
Environment

Losses

Policy 1

Policy 2

Policy 3

ちなみにプールの素材はこちらです。
https://assetstore.unity.com/packages/3d/environments/japanese-school-swimming-pool-20487
これらの記事の続編を書くので、また見てね。
20年間ありがとう。これからもよろしく。
おわり