こんにちは、四谷ラボのやましんです。
晴れて結婚20周年を迎えることができました!!
こうやって、書きたいときにブログを書けるのも、周りの人が支えてくれるからだね。
ありがとう!
結果発表
さっそく「GAFA。水面下の争いに決着。」を目撃すべく、この映像を見て欲しい。
後半の追い上げがハンパなく、GAFAの中で優勝を獲得したのはAppleでした。
※仮想環境で文字列の泳ぎを学習させて、競わせたものであり、各社の業績とは無関係です。
今回挑戦した内容
バーチャル空間でGoogle、Apple、Facebookそしてamazonの文字列をくねらせて、水泳の機械学習に挑戦しました。
機械学習の環境
Unity バージョン2020.1.10f1
ml-agents version 1.1.0
https://github.com/Unity-Technologies/ml-agents
文字列を泳がせて学習するアイデア
何はともあれ、文字列を構成する文字を3次元化する。
こんな感じ。これは「よつやらぼ」だよ。わかったよね。
文字を3次元化できるアセットがあったので、これを利用しました。
https://assetstore.unity.com/packages/tools/particles-effects/flyingtext3d-3627
直接関係ないけど、最近3Dプリントが楽しくて、上の文字列を具現化してみた。
次に、3次元化した文字列に浮力を与える。
めちゃくちゃいいのがあった。浮力のシミュレーションがとてもうまく再現できる。
https://github.com/dbrizov/NaughtyWaterBuoyancy
さらに、隣り合う文字同士を関節のように接続する。
Unity configurable jointを使って、隣り合う文字同士を接続することにした。
上から見ると、各間接の可動域が一律で180°であることがわかるね。
それと、文字同士の衝突判定はしないよ。
最後は、こんな感じにできました。
よし!他の文字列も同じ要領で作るぞ。
実装
上記で作った文字列のように、文字間の可動域を一定(最大180°)にした状態で、各関節をどのように運動させれば、大きな推進力が得られるかを機械学習(PPO)で推定する。
機械学習の報酬とペナルティ。
ポイントは、報酬とペナルティ。
試行錯誤した結果が次のもの。
- Time penalty
時間が経過すると共にマイナス(ペナルティ)
- Velocity Rewards
文字列のゴール(プール左端)に対する相対速度が速ければ報酬。
- Direction penalty
文字列がゴールに向いていれば報酬。
- Course out penalty
文字列がレーンからずれたらペナルティ
- Reached target
文字列がゴール到達すれば報酬
- Fell off platform
文字列がプラットフォームから落ちるとペナルティ
プログラム
このプログラムは、最も重要な ml-agents の Agentクラスを継承したクラスです。 今回、有料アセット(文字列3次元化、小学校プール)を使用しているので、githubには共有しません。 (Unity有料アセットを含むプログラムのオープンソース化(バイナリ化?)について、知っている人教えて下さい。)
using UnityEngine; using Unity.MLAgents; using Unity.MLAgents.Sensors; using FlyingText3D; using NaughtyWaterBuoyancy; using System; using System.Collections.Generic; using System.Linq; using UnityEngine.UI; public class SwimmingAgent : Agent { [Header("Specific to SwimmingAgent")] private List<string> colors = new List<string>{"white", "black", "red","blue","yellow"}; public string SwimingTextString; public string SwimingTextColor; public int course_num; private List<ConfigurableJoint> joint_list = new List<ConfigurableJoint>(); private List<Vector3> current_normalized_rot = new List<Vector3>(); private List<GameObject> Text3Ds = new List<GameObject>(); private Rigidbody initialCharacter = null; public float maxJointSpring; public float jointDampen; public float maxJointForceLimit; public rerowrds total; public Text totalText; public Text totalTotalText; public rerowrds time; public Text timeText; public Text timeTotalText; public rerowrds velocity; public Text velocityText; public Text velocityTotalText; public rerowrds direction; public Text directionText; public Text directionTotalText; public rerowrds touch_target; public Text touch_targetText; public Text touch_targetTotalText; public rerowrds course_out; public Text course_outText; public Text course_outTotalText; Vector3 m_DirToTarget; //private static FlyingText instance;Vector3 m_DirToTarget Quaternion m_LookRotation; Matrix4x4 m_TargetDirMatrix; //private int color_index = 0; GameObject CreateFlyingText(String text, float pos) { FlyingText.colliderType = ColliderType.ConvexMesh; FlyingText.defaultSize = 0.8f; FlyingText.defaultDepth = 0.1f; GameObject objectParent = FlyingText.GetObjects("<color=" + SwimingTextColor +">" + text) ; // GameObject objectParent2 = FlyingText.GetObjects(text); GameObject charPolygon = objectParent.transform.GetChild(0).gameObject; FloatingObject fo = charPolygon.AddComponent<FloatingObject>() as FloatingObject; fo.dragInWater = 0.0001f; fo.angularDragInWater = 0.0001f; //fo.density = 0.75f; //Vector3 org_pos = new Vector3(9.5f, 1.0f, 0.0f); Vector3 org_pos = new Vector3(0.0f, 1.0f, 0.0f); //objectParent.transform.rotation = new Quaternion(90.0f, -90.0f, -90.0f, 0.0f); objectParent.transform.position = new Vector3(org_pos.x + pos, org_pos.y, org_pos.z); return charPolygon; } public override void Initialize() { CreateText3Ds(); SetResetParameters(); } public override void CollectObservations(VectorSensor sensor) { for (int i = 0; i < Text3Ds.Count(); i++) { GameObject charPolygon = Text3Ds[i]; /* sensor.AddObservation(charPolygon.transform.rotation); sensor.AddObservation(charPolygon.transform.position); */ Rigidbody rb = charPolygon.GetComponent<Rigidbody>(); sensor.AddObservation(rb.velocity); sensor.AddObservation(rb.transform.position); sensor.AddObservation(rb.transform.rotation); sensor.AddObservation(rb.centerOfMass); } for(int i = 0;i < joint_list.Count();i++) { sensor.AddObservation(current_normalized_rot[i]); } } // public float speed = 10; Vector3 PrevLocation; public override void OnActionReceived(float[] vectorAction) { int idx = 0; for(int i = 0;i < joint_list.Count();i++) { SetJointTargetRotation(i,0.0f, vectorAction[idx++], 0.0f); //SetJointStrength(joint_list[i],vectorAction[idx++]); SetJointStrength(joint_list[i],1.0f); } for (int i = 0; i < Text3Ds.Count(); i++) { Rigidbody rb = Text3Ds[i].GetComponent<Rigidbody>(); rb.centerOfMass = new Vector3(0f, (vectorAction[idx++]-0.5f), 0f); } CalcRewards(); } public struct rerowrds { public Boolean is_use; public float total; public float current_value; public float rate; public Boolean is_penalty; }; void SetScores() { if (timeText) { timeText.text = time.current_value.ToString(); timeTotalText.text = time.total.ToString(); velocityText.text = velocity.current_value.ToString(); velocityTotalText.text = velocity.total.ToString(); directionText.text = direction.current_value.ToString(); directionTotalText.text = direction.total.ToString(); course_outText.text = course_out.current_value.ToString(); course_outTotalText.text = course_out.total.ToString(); touch_targetText.text = touch_target.current_value.ToString(); touch_targetTotalText.text = touch_target.total.ToString(); totalText.text = total.current_value.ToString(); totalTotalText.text = total.total.ToString(); } } void CalcRewards() { //Time penalty if(time.is_use) { time.current_value = time.rate * (time.is_penalty ? -1:1); AddReward(time.current_value); time.total += time.current_value; } //Velocity Rewards if (velocity.is_use) { velocity.current_value = -initialCharacter.velocity.x * velocity.rate * (velocity.is_penalty ? -1:1); AddReward(velocity.current_value); velocity.total += velocity.current_value; } //Direction penalty if(direction.is_use) { direction.current_value = (.5f - Math.Abs(initialCharacter.transform.rotation.y)) * direction.rate * (direction.is_penalty ? -1:1); AddReward(direction.current_value); direction.total += direction.current_value; } //Course out penalty if(course_out.is_use) { course_out.current_value = Math.Abs(initialCharacter.transform.position.z+(5.0f - course_num * 2.5f)) * course_out.rate * (course_out.is_penalty ? -1:1); AddReward(course_out.current_value); course_out.total += course_out.current_value; /* course_out.current_value = Math.Abs(initialCharacter.velocity.z) * course_out.rate * (course_out.is_penalty ? -1:1); AddReward(course_out.current_value); course_out.total += course_out.current_value; */ } // Reached target if(touch_target.is_use) { float distanceToTarget = initialCharacter.transform.position.x - Target.position.x; if (distanceToTarget < 2f) { touch_target.current_value = touch_target.rate * (touch_target.is_penalty ? -1:1); AddReward(touch_target.current_value); touch_target.total += touch_target.current_value; EndEpisode(); } } touch_target.current_value = 0; // Fell off platform if (initialCharacter.transform.position.y < -20) { EndEpisode(); } total.current_value = time.total + velocity.total + direction.total + course_out.total + touch_target.total; total.total += total.current_value; SetScores(); } public Transform Target; public override void OnEpisodeBegin() { SetResetParameters(); } public override void Heuristic(float[] actionsOut) { actionsOut[0] = -Input.GetAxis("Horizontal"); actionsOut[1] = Input.GetAxis("Vertical"); } //int a=0; private void ResetText3Ds() { float interval = 0.0f; float current_pos = 0.0f; SetConfigrableJoint(); for (int i = 0; i < Text3Ds.Count(); i++) { GameObject charPolygon = Text3Ds[i]; //Vector3 org_pos = new Vector3(9.5f, -9.0f, - 5.0f + course_num * 2.5f ); Vector3 org_pos = new Vector3(0.0f, -9.0f, - 5.0f + course_num * 2.5f ); charPolygon.transform.position = new Vector3(org_pos.x + current_pos, org_pos.y, org_pos.z); charPolygon.transform.rotation = new Quaternion(0.0f,0.0f,0.0f,0.0f); MeshCollider mc = charPolygon.GetComponent<MeshCollider>(); //mc.isTrigger = true; float width = mc.bounds.size.x; current_pos += width + interval; } } private void SetConfigrableJoint() { for (int i = 0; i < joint_list.Count(); i++) { joint_list[i].targetRotation = new Quaternion(0.0f, 0.0f, 0.0f, 0.0f); current_normalized_rot[i] = new Vector3(0.0f,0.0f,0.0f); } } GameObject pre_charPolygon = null; private void CreateConfigrableJoint(GameObject charPolygon) { if (pre_charPolygon) { float height = charPolygon.GetComponent<MeshCollider>().bounds.size.y; float depth = charPolygon.GetComponent<MeshCollider>().bounds.size.z; ConfigurableJoint cj = charPolygon.AddComponent<ConfigurableJoint>() as ConfigurableJoint; joint_list.Add(cj); current_normalized_rot.Add( new Vector3(0.0f, 0.0f, 0.0f)); cj.xMotion = ConfigurableJointMotion.Locked; cj.yMotion = ConfigurableJointMotion.Locked; cj.zMotion = ConfigurableJointMotion.Locked; cj.angularXMotion = ConfigurableJointMotion.Locked; cj.angularYMotion = ConfigurableJointMotion.Limited; cj.angularYLimit = new SoftJointLimit() { limit = 90.0f }; cj.angularZMotion = ConfigurableJointMotion.Locked; /** cj.angularZMotion = ConfigurableJointMotion.Limited; cj.angularZLimit = new SoftJointLimit() { limit = 90.0f }; */ cj.anchor = new Vector3(0.0f, height / 2.0f, depth/2.0f); cj.axis = new Vector3(0.0f, 0.0f, -1.0f); cj.massScale = 1000; cj.connectedMassScale = 1000; cj.enableCollision = false; //ConfigurableJoint joint = charPolygon.GetComponent<ConfigurableJoint>() as ConfigurableJoint; Rigidbody rb = pre_charPolygon.GetComponent<Rigidbody>(); cj.connectedBody = rb; } pre_charPolygon = charPolygon; } private void CreateText3Ds() { float interval = 0.0f; float current_pos = 0.0f; char[] c = SwimingTextString.ToCharArray(); //Rigidbody pre_Rigidbody = null; for (int i = 0; i < c.Length; i++) { string s = new String(new char[] { c[i] }); GameObject charPolygon = CreateFlyingText(s, current_pos); // color_index ++; charPolygon.layer = 8; Text3Ds.Add(charPolygon); MeshCollider mc = charPolygon.GetComponent<MeshCollider>(); float width = mc.bounds.size.x; current_pos += width + interval; CreateConfigrableJoint(charPolygon); Rigidbody rb = charPolygon.GetComponent<Rigidbody>(); rb.centerOfMass = new Vector3(0f, -0.5f, 0f); if(!initialCharacter) { initialCharacter = rb; //initialCharacter.isKinematic = true; } } } public void ResetRewords() { total.total = 0.0f; total.current_value = 0.0f; total.is_penalty = false; total.rate = 1.0f; touch_target.is_use =true; touch_target.total = 0.0f; touch_target.current_value = 0.0f; touch_target.is_penalty = false; touch_target.rate = 1000.0f; velocity.is_use =true; velocity.total = 0.0f; velocity.current_value = 0.0f; velocity.is_penalty = false; velocity.rate = 1.0f; direction.is_use =true; direction.total = 0.0f; direction.current_value = 0.0f; direction.is_penalty = false; direction.rate = 1.0f; time.is_use =true; time.total = 0.0f; time.current_value = 0.0f; time.is_penalty = true; time.rate = 1.0f; course_out.is_use =true; course_out.total = 0.0f; course_out.current_value = 0.0f; course_out.is_penalty = true; course_out.rate = .1f; } public void SetResetParameters() { ResetRewords(); ResetText3Ds(); } public void SetJointTargetRotation(int i, float x, float y, float z) { ConfigurableJoint joint = joint_list[i]; x = (x + 1f) * 0.5f; y = (y + 1f) * 0.5f; z = (z + 1f) * 0.5f; var xRot = Mathf.Lerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, x); var yRot = Mathf.Lerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, y); var zRot = Mathf.Lerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, z); current_normalized_rot[i] = new Vector3(Mathf.InverseLerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, xRot), Mathf.InverseLerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, yRot), Mathf.InverseLerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, zRot)); joint.targetRotation = Quaternion.Euler(xRot, yRot, zRot); } public void SetJointStrength(ConfigurableJoint joint, float strength) { var rawVal = (strength + 1f) * 0.5f * maxJointForceLimit; var jd = new JointDrive { positionSpring = maxJointSpring, positionDamper = jointDampen, maximumForce = rawVal }; joint.slerpDrive = jd; //currentStrength = jd.maximumForce; } }
機械学習時間の文字列の泳ぎの変化
GAFA水泳トレーニング開始直後
みんな、がんばれー。終わったら、はちみつレモンあるわよ。
機械学習2時間経過
ゴールは左側なんだけどなぁ。
機械学習5時間経過
全体的にゴールのある左側へ向かっているのがわかる。
ちなみに、オリンピック水泳選手は、小学生でも1日4,5時間泳ぐこともあるらしい。その辺にいる両生類より水に長く入ってるんちゃう?いい意味で。
機械学習42時間経過
まっすぐ、泳げる個体が増えてきた。
機械学習72時間経過
42時間と比較しても、速度が上がっているのがわかる。
機械学習状況のグラフ(TensorBoard)
Environment
Losses
Policy 1
Policy 2
Policy 3
ちなみにプールの素材はこちらです。
https://assetstore.unity.com/packages/3d/environments/japanese-school-swimming-pool-20487
これらの記事の続編を書くので、また見てね。
20年間ありがとう。これからもよろしく。
おわり