四谷ラボ公式ブログ

四谷ラボはいつでも誰でも自由に参加・研究・交流・発信のできる街のオープンイノベーションラボ

GAFA。水面下の争いに決着。

f:id:yamashin0922:20201030080853p:plain

こんにちは、四谷ラボのやましんです。

晴れて結婚20周年を迎えることができました!!

こうやって、書きたいときにブログを書けるのも、周りの人が支えてくれるからだね。

ありがとう!

結果発表

さっそく「GAFA。水面下の争いに決着。」を目撃すべく、この映像を見て欲しい。

youtu.be

後半の追い上げがハンパなく、GAFAの中で優勝を獲得したのはAppleでした。

※仮想環境で文字列の泳ぎを学習させて、競わせたものであり、各社の業績とは無関係です。

今回挑戦した内容

バーチャル空間でGoogleAppleFacebookそしてamazonの文字列をくねらせて、水泳の機械学習に挑戦しました。

機械学習の環境

  • Unity バージョン2020.1.10f1

  • ml-agents version 1.1.0

https://github.com/Unity-Technologies/ml-agents

文字列を泳がせて学習するアイデア

何はともあれ、文字列を構成する文字を3次元化する。

f:id:yamashin0922:20201030004256p:plain

こんな感じ。これは「よつやらぼ」だよ。わかったよね。

文字を3次元化できるアセットがあったので、これを利用しました。

https://assetstore.unity.com/packages/tools/particles-effects/flyingtext3d-3627

直接関係ないけど、最近3Dプリントが楽しくて、上の文字列を具現化してみた。

f:id:yamashin0922:20201030012816j:plain

次に、3次元化した文字列に浮力を与える。

めちゃくちゃいいのがあった。浮力のシミュレーションがとてもうまく再現できる。 f:id:yamashin0922:20201030002649g:plain

https://github.com/dbrizov/NaughtyWaterBuoyancy

さらに、隣り合う文字同士を関節のように接続する。

Unity configurable jointを使って、隣り合う文字同士を接続することにした。

上から見ると、各間接の可動域が一律で180°であることがわかるね。

それと、文字同士の衝突判定はしないよ。

youtu.be

最後は、こんな感じにできました。

youtu.be

よし!他の文字列も同じ要領で作るぞ。

実装

上記で作った文字列のように、文字間の可動域を一定(最大180°)にした状態で、各関節をどのように運動させれば、大きな推進力が得られるかを機械学習(PPO)で推定する。

機械学習の報酬とペナルティ。

ポイントは、報酬とペナルティ。

試行錯誤した結果が次のもの。

  • Time penalty

 時間が経過すると共にマイナス(ペナルティ)

  • Velocity Rewards

 文字列のゴール(プール左端)に対する相対速度が速ければ報酬。

  • Direction penalty

 文字列がゴールに向いていれば報酬。

  • Course out penalty

 文字列がレーンからずれたらペナルティ

  • Reached target

 文字列がゴール到達すれば報酬

  • Fell off platform

 文字列がプラットフォームから落ちるとペナルティ

プログラム

このプログラムは、最も重要な ml-agents の Agentクラスを継承したクラスです。 今回、有料アセット(文字列3次元化、小学校プール)を使用しているので、githubには共有しません。 (Unity有料アセットを含むプログラムのオープンソース化(バイナリ化?)について、知っている人教えて下さい。)

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using FlyingText3D;
using NaughtyWaterBuoyancy;
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine.UI;

public class SwimmingAgent : Agent
{
    [Header("Specific to SwimmingAgent")]

    private List<string> colors = new List<string>{"white", "black", "red","blue","yellow"};

    public string SwimingTextString;
    public string SwimingTextColor;

    public int course_num;

    private List<ConfigurableJoint> joint_list = new List<ConfigurableJoint>();
    private List<Vector3> current_normalized_rot = new List<Vector3>();
    private List<GameObject> Text3Ds = new List<GameObject>();
    
    private Rigidbody initialCharacter = null;

    public float maxJointSpring;

    public float jointDampen;
    public float maxJointForceLimit;


    public rerowrds total;
    public Text  totalText;
    public Text  totalTotalText;
    
    public rerowrds time;
    public Text  timeText;
    public Text  timeTotalText;


    public rerowrds velocity;
    public Text velocityText;
    public Text velocityTotalText;

    public rerowrds direction;
    public Text directionText;
    public Text directionTotalText;

    public rerowrds touch_target;
    public Text touch_targetText;
    public Text touch_targetTotalText;

    public rerowrds course_out;
    public Text course_outText;
    public Text course_outTotalText;
    Vector3 m_DirToTarget;
    //private static FlyingText instance;Vector3 m_DirToTarget    
    Quaternion m_LookRotation;
    Matrix4x4 m_TargetDirMatrix;
    //private int color_index = 0;

    GameObject CreateFlyingText(String text, float pos)
    {
        FlyingText.colliderType = ColliderType.ConvexMesh;
        FlyingText.defaultSize = 0.8f;
        FlyingText.defaultDepth = 0.1f;
        GameObject objectParent = FlyingText.GetObjects("<color=" + SwimingTextColor +">" + text) ;

               // GameObject objectParent2 = FlyingText.GetObjects(text);
        GameObject charPolygon = objectParent.transform.GetChild(0).gameObject;

        FloatingObject fo = charPolygon.AddComponent<FloatingObject>() as FloatingObject;
        fo.dragInWater = 0.0001f;
        fo.angularDragInWater = 0.0001f;
        //fo.density = 0.75f;

        //Vector3 org_pos = new Vector3(9.5f, 1.0f, 0.0f);
        Vector3 org_pos = new Vector3(0.0f, 1.0f, 0.0f);
        //objectParent.transform.rotation = new Quaternion(90.0f, -90.0f, -90.0f, 0.0f);
        objectParent.transform.position = new Vector3(org_pos.x + pos, org_pos.y, org_pos.z);
        return charPolygon;
    }


    public override void Initialize()
    {
        CreateText3Ds();
        SetResetParameters();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        for (int i = 0; i < Text3Ds.Count(); i++)
        {
            GameObject charPolygon = Text3Ds[i];
            /*
            sensor.AddObservation(charPolygon.transform.rotation);
            sensor.AddObservation(charPolygon.transform.position);
            */
            Rigidbody rb = charPolygon.GetComponent<Rigidbody>();
            sensor.AddObservation(rb.velocity);
            sensor.AddObservation(rb.transform.position);
            sensor.AddObservation(rb.transform.rotation);
            sensor.AddObservation(rb.centerOfMass);
        }
        for(int i = 0;i < joint_list.Count();i++)
        {
            sensor.AddObservation(current_normalized_rot[i]);
        }
    }
   // public float speed = 10;
    Vector3 PrevLocation;

    public override void OnActionReceived(float[] vectorAction)
    {

        int idx = 0;
        for(int i = 0;i < joint_list.Count();i++)
        {
            SetJointTargetRotation(i,0.0f, vectorAction[idx++], 0.0f);
            //SetJointStrength(joint_list[i],vectorAction[idx++]);
            SetJointStrength(joint_list[i],1.0f);
        }
        
        for (int i = 0; i < Text3Ds.Count(); i++)
        {
            Rigidbody rb = Text3Ds[i].GetComponent<Rigidbody>();
            rb.centerOfMass = new Vector3(0f, (vectorAction[idx++]-0.5f), 0f);
        }
        
        CalcRewards();
    }

    public struct rerowrds
    {
        public Boolean is_use;
        public float total;
        public float current_value;
        public float rate;
        public Boolean is_penalty;
    };

    void SetScores()
    {
        if (timeText)
        {
            timeText.text = time.current_value.ToString();
            timeTotalText.text = time.total.ToString();

            velocityText.text = velocity.current_value.ToString();
            velocityTotalText.text = velocity.total.ToString();

            directionText.text = direction.current_value.ToString();
            directionTotalText.text = direction.total.ToString();
            
            course_outText.text = course_out.current_value.ToString();
            course_outTotalText.text = course_out.total.ToString();

            touch_targetText.text = touch_target.current_value.ToString();
            touch_targetTotalText.text = touch_target.total.ToString();

            totalText.text = total.current_value.ToString();
            totalTotalText.text = total.total.ToString();
        }
    }
    void CalcRewards()
    {

        //Time penalty
        if(time.is_use)
        {
            time.current_value = time.rate * (time.is_penalty ? -1:1);
            AddReward(time.current_value);
            time.total += time.current_value;
        }
       
        //Velocity Rewards
        if (velocity.is_use)
        {
            velocity.current_value = -initialCharacter.velocity.x * velocity.rate * (velocity.is_penalty ? -1:1);
            AddReward(velocity.current_value);
            velocity.total += velocity.current_value;
        }

        //Direction penalty
        if(direction.is_use)
        {
            direction.current_value = (.5f - Math.Abs(initialCharacter.transform.rotation.y)) * direction.rate * (direction.is_penalty ? -1:1);
            AddReward(direction.current_value);
            direction.total += direction.current_value;
        }

        //Course out penalty
        if(course_out.is_use)
        {
            course_out.current_value = Math.Abs(initialCharacter.transform.position.z+(5.0f - course_num * 2.5f)) * course_out.rate * (course_out.is_penalty ? -1:1);
            AddReward(course_out.current_value);
            course_out.total += course_out.current_value;
            /*
            course_out.current_value = Math.Abs(initialCharacter.velocity.z) * course_out.rate * (course_out.is_penalty ? -1:1);
            AddReward(course_out.current_value);
            course_out.total += course_out.current_value;
            */     
        }
 
        // Reached target
        if(touch_target.is_use)
        {
            float distanceToTarget =   initialCharacter.transform.position.x - Target.position.x;
            if (distanceToTarget < 2f)
            {
                touch_target.current_value = touch_target.rate * (touch_target.is_penalty ? -1:1);
                AddReward(touch_target.current_value);
                touch_target.total += touch_target.current_value;
                EndEpisode();
            }
        }
        touch_target.current_value = 0;
        // Fell off platform
        if (initialCharacter.transform.position.y < -20)
        {
            EndEpisode();
        }
        total.current_value = time.total + velocity.total + direction.total + course_out.total + touch_target.total;
        total.total += total.current_value;


        SetScores();
    }
    public Transform Target;

    public override void OnEpisodeBegin()
    {
        SetResetParameters();
    }

    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = -Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }
    //int a=0;
    private void ResetText3Ds()
    {
        float interval = 0.0f;
        float current_pos = 0.0f;
        SetConfigrableJoint();
        for (int i = 0; i < Text3Ds.Count(); i++)
        {
            GameObject charPolygon = Text3Ds[i];
            //Vector3 org_pos = new Vector3(9.5f, -9.0f, - 5.0f + course_num * 2.5f );
            Vector3 org_pos = new Vector3(0.0f, -9.0f, - 5.0f + course_num * 2.5f );
            charPolygon.transform.position = new Vector3(org_pos.x + current_pos, org_pos.y, org_pos.z);
            charPolygon.transform.rotation = new Quaternion(0.0f,0.0f,0.0f,0.0f);
            MeshCollider mc = charPolygon.GetComponent<MeshCollider>();
            //mc.isTrigger = true;
            float width = mc.bounds.size.x;
            current_pos += width + interval;
        }
    }

    private void SetConfigrableJoint()
    {
        for (int i = 0; i < joint_list.Count(); i++)
        {
            joint_list[i].targetRotation = new Quaternion(0.0f, 0.0f, 0.0f, 0.0f);
            current_normalized_rot[i] = new Vector3(0.0f,0.0f,0.0f);
        }
    }
    GameObject pre_charPolygon = null;
    private void CreateConfigrableJoint(GameObject charPolygon)
    {
        if (pre_charPolygon)
        {
            float height = charPolygon.GetComponent<MeshCollider>().bounds.size.y;
            float depth = charPolygon.GetComponent<MeshCollider>().bounds.size.z;
            ConfigurableJoint cj = charPolygon.AddComponent<ConfigurableJoint>() as ConfigurableJoint;
            joint_list.Add(cj);
            current_normalized_rot.Add( new Vector3(0.0f, 0.0f, 0.0f));
            cj.xMotion = ConfigurableJointMotion.Locked;
            cj.yMotion = ConfigurableJointMotion.Locked;
            cj.zMotion = ConfigurableJointMotion.Locked;
            cj.angularXMotion = ConfigurableJointMotion.Locked;
            cj.angularYMotion = ConfigurableJointMotion.Limited;
            cj.angularYLimit = new SoftJointLimit() { limit = 90.0f };
            cj.angularZMotion = ConfigurableJointMotion.Locked;
            /**
            cj.angularZMotion = ConfigurableJointMotion.Limited;
            cj.angularZLimit = new SoftJointLimit() { limit = 90.0f };            
            */
            cj.anchor = new Vector3(0.0f, height / 2.0f, depth/2.0f);
            cj.axis = new Vector3(0.0f, 0.0f, -1.0f);
            cj.massScale = 1000;
            cj.connectedMassScale = 1000;
            cj.enableCollision = false;

            //ConfigurableJoint joint = charPolygon.GetComponent<ConfigurableJoint>() as ConfigurableJoint;
            Rigidbody rb = pre_charPolygon.GetComponent<Rigidbody>();
            cj.connectedBody = rb;
        }
        pre_charPolygon = charPolygon;
    }
    private void CreateText3Ds()
    {
        float interval = 0.0f;
        float current_pos = 0.0f;
        char[] c = SwimingTextString.ToCharArray();

        //Rigidbody pre_Rigidbody = null;
        for (int i = 0; i < c.Length; i++)
        {
            string s = new String(new char[] { c[i] });
            
            GameObject charPolygon = CreateFlyingText(s, current_pos);
           // color_index ++;
            charPolygon.layer = 8;

            Text3Ds.Add(charPolygon);
            MeshCollider mc = charPolygon.GetComponent<MeshCollider>();
            float width = mc.bounds.size.x;
            current_pos += width + interval;
            CreateConfigrableJoint(charPolygon);
            Rigidbody rb = charPolygon.GetComponent<Rigidbody>();
            rb.centerOfMass = new Vector3(0f, -0.5f, 0f);
            if(!initialCharacter)
            {
                initialCharacter = rb;
                //initialCharacter.isKinematic = true;
            }
        }
    }
    public void ResetRewords()
    {
        total.total = 0.0f;
        total.current_value = 0.0f;
        total.is_penalty = false;
        total.rate = 1.0f;

        touch_target.is_use =true;
        touch_target.total = 0.0f;
        touch_target.current_value = 0.0f;
        touch_target.is_penalty = false;
        touch_target.rate = 1000.0f;

        velocity.is_use =true;
        velocity.total = 0.0f;
        velocity.current_value = 0.0f;        
        velocity.is_penalty = false;
        velocity.rate = 1.0f;

        direction.is_use =true;
        direction.total = 0.0f;
        direction.current_value = 0.0f;        
        direction.is_penalty = false;
        direction.rate = 1.0f;

        time.is_use =true;
        time.total = 0.0f;
        time.current_value = 0.0f;        
        time.is_penalty = true;
        time.rate = 1.0f;

        course_out.is_use =true;
        course_out.total = 0.0f;
        course_out.current_value = 0.0f;
        course_out.is_penalty = true;
        course_out.rate = .1f;

    }
    public void SetResetParameters()
    {
        ResetRewords();
        ResetText3Ds();
    }

    public void SetJointTargetRotation(int  i, float x, float y, float z)
    {
        ConfigurableJoint joint = joint_list[i];
        x = (x + 1f) * 0.5f;
        y = (y + 1f) * 0.5f;
        z = (z + 1f) * 0.5f;

        var xRot = Mathf.Lerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, x);
        var yRot = Mathf.Lerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, y);
        var zRot = Mathf.Lerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, z);
        current_normalized_rot[i] = new Vector3(Mathf.InverseLerp(joint.lowAngularXLimit.limit, joint.highAngularXLimit.limit, xRot),
                                                Mathf.InverseLerp(-joint.angularYLimit.limit, joint.angularYLimit.limit, yRot),
                                                Mathf.InverseLerp(-joint.angularZLimit.limit, joint.angularZLimit.limit, zRot));
        joint.targetRotation = Quaternion.Euler(xRot, yRot, zRot);

    }


    public void SetJointStrength(ConfigurableJoint joint, float strength)
    {
        var rawVal = (strength + 1f) * 0.5f * maxJointForceLimit;
        var jd = new JointDrive
        {
            positionSpring = maxJointSpring,
            positionDamper = jointDampen,
            maximumForce = rawVal
        };
        joint.slerpDrive = jd;
        //currentStrength = jd.maximumForce;
    }

}

機械学習時間の文字列の泳ぎの変化

GAFA水泳トレーニング開始直後

みんな、がんばれー。終わったら、はちみつレモンあるわよ。

youtu.be

機械学習2時間経過

ゴールは左側なんだけどなぁ。

youtu.be

機械学習5時間経過

全体的にゴールのある左側へ向かっているのがわかる。

ちなみに、オリンピック水泳選手は、小学生でも1日4,5時間泳ぐこともあるらしい。その辺にいる両生類より水に長く入ってるんちゃう?いい意味で。

youtu.be

機械学習42時間経過

まっすぐ、泳げる個体が増えてきた。

youtu.be

機械学習72時間経過

42時間と比較しても、速度が上がっているのがわかる。

youtu.be

機械学習状況のグラフ(TensorBoard)

Environment

f:id:yamashin0922:20201030023437p:plain

Losses

f:id:yamashin0922:20201030023432p:plain

Policy 1

f:id:yamashin0922:20201030023428p:plain

Policy 2

f:id:yamashin0922:20201030023422p:plain

Policy 3

f:id:yamashin0922:20201030023417p:plain

ちなみにプールの素材はこちらです。

https://assetstore.unity.com/packages/3d/environments/japanese-school-swimming-pool-20487

これらの記事の続編を書くので、また見てね。

blog.428lab.net

blog.428lab.net

20年間ありがとう。これからもよろしく。

おわり