From 5f319dc417e351618201c57b818aaea08f5666c3 Mon Sep 17 00:00:00 2001
From: Jesse Mapel <jam826@nau.edu>
Date: Mon, 25 Nov 2019 15:31:26 -0700
Subject: [PATCH] Adds Rotation class (#311)

* In progress rotations

* Added more tests

* Removed unsued rotation methods and added more tests

* One more exception test

* Added docs to Rotation class
---
 CMakeLists.txt                |   5 +-
 include/Rotation.h            | 136 +++++++++++
 src/Rotation.cpp              | 271 +++++++++++++++++++++
 tests/ctests/RotationTest.cpp | 438 ++++++++++++++++++++++++++++++++++
 4 files changed, 848 insertions(+), 2 deletions(-)
 create mode 100644 include/Rotation.h
 create mode 100644 src/Rotation.cpp
 create mode 100644 tests/ctests/RotationTest.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3880b1e..9a1316e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,9 +22,10 @@ find_package(nlohmann_json REQUIRED)
 
 # Library setup
 add_library(ale SHARED
-            ${CMAKE_CURRENT_SOURCE_DIR}/src/ale.cpp)
+            ${CMAKE_CURRENT_SOURCE_DIR}/src/ale.cpp
+            ${CMAKE_CURRENT_SOURCE_DIR}/src/Rotation.cpp)
 # Alias a scoped target for safer linking in downstream projects
-set(ALE_HEADERS "include/ale.h")
+set(ALE_HEADERS "include/ale.h, include/Rotation.h")
 set_target_properties(ale PROPERTIES
                       VERSION       ${PROJECT_VERSION}
                       SOVERSION     0
diff --git a/include/Rotation.h b/include/Rotation.h
new file mode 100644
index 0000000..49f0958
--- /dev/null
+++ b/include/Rotation.h
@@ -0,0 +1,136 @@
+#ifndef ALE_ROTATION_H
+#define ALE_ROTATION_H
+
+#include <memory>
+#include <vector>
+
+namespace ale {
+
+  enum RotationInterpolation {
+    slerp, // Spherical interpolation
+    nlerp // Normalized linear interpolation
+  };
+
+  /**
+   * A generic 3D rotation.
+   */
+  class Rotation {
+    public:
+      /**
+       * Construct a default identity rotation.
+       */
+      Rotation();
+      /**
+       * Construct a rotation from a quaternion.
+       *
+       * @param w The scalar component of the quaternion.
+       * @param x The x value of the vector component of the quaternion.
+       * @param y The y value of the vector component of the quaternion.
+       * @param z The z value of the vector component of the quaternion.
+       */
+      Rotation(double w, double x, double y, double z);
+      /**
+       * Construct a rotation from a rotation matrix.
+       *
+       * @param matrix The rotation matrix in row-major order.
+       */
+      Rotation(const std::vector<double>& matrix);
+      /**
+       * Construct a rotation from a set of Euler angle rotations.
+       *
+       * @param angles A vector of rotations about the axes.
+       * @param axes The vector of axes to rotate about, in order.
+       *             0 is X, 1 is Y, and 2 is Z.
+       */
+      Rotation(const std::vector<double>& angles, const std::vector<int>& axes);
+      /**
+       * Construct a rotation from a rotation about an axis.
+       *
+       * @param axis The axis of rotation.
+       * @param theta The rotation about the axis in radians.
+       */
+      Rotation(const std::vector<double>& axis, double theta);
+      ~Rotation();
+
+      // Special member functions
+      Rotation(Rotation && other) noexcept;
+      Rotation& operator=(Rotation && other) noexcept;
+
+      Rotation(const Rotation& other);
+      Rotation& operator=(const Rotation& other);
+
+      // Type specific accessors
+      /**
+       * The rotation as a quaternion.
+       *
+       * @return The rotation as a scalar-first quaternion (w, x, y, z).
+       */
+      std::vector<double> toQuaternion() const;
+      /**
+       * The rotation as a rotation matrix.
+       *
+       * @return The rotation as a rotation matrix in row-major order.
+       */
+      std::vector<double> toRotationMatrix() const;
+      /**
+       * Create a state rotation matrix from the rotation and an angula velocity.
+       *
+       * @param av The angular velocity vector.
+       *
+       * @return The state rotation matrix in row-major order.
+       */
+      std::vector<double> toStateRotationMatrix(const std::vector<double> &av) const;
+      /**
+       * The rotation as Euler angles.
+       *
+       * @param axes The axis order. 0 is X, 1 is Y, and 2 is Z.
+       *
+       * @return The rotations about the axes in radians.
+       */
+      std::vector<double> toEuler(const std::vector<int>& axes) const;
+      /**
+       * The rotation as a rotation about an axis.
+       *
+       * @return the axis of rotation and rotation in radians.
+       */
+      std::pair<std::vector<double>, double> toAxisAngle() const;
+
+      // Generic rotation operations
+      /**
+       * Rotate a vector
+       *
+       * @param vector The vector to rotate. Cab be a 3 element position or 6 element state.
+       * @param av The angular velocity to use when rotating state vectors. Defaults to 0.
+       *
+       * @return The rotated vector.
+       */
+      std::vector<double> operator()(const std::vector<double>& vector, const std::vector<double>& av = {0.0, 0.0, 0.0}) const;
+      /**
+       * Get the inverse rotation.
+       */
+      Rotation inverse() const;
+      /**
+       * Chain this rotation with another rotation.
+       *
+       * Rotations are sequenced right to left.
+       */
+      Rotation operator*(const Rotation& rightRotation) const;
+      /**
+       * Interpolate between this rotation and another rotation.
+       *
+       * @param t The distance to interpolate. 0 is this and 1 is the next rotation.
+       * @param interpType The type of rotation interpolation to use.
+       *
+       * @param The interpolated rotation.
+       */
+      Rotation interpolate(const Rotation& nextRotation, double t, RotationInterpolation interpType) const;
+
+    private:
+      // Implementation class
+      class Impl;
+      // Pointer to internal rotation implementation.
+      std::unique_ptr<Impl> m_impl;
+  };
+}
+
+#endif
diff --git a/src/Rotation.cpp b/src/Rotation.cpp
new file mode 100644
index 0000000..46d68e0
--- /dev/null
+++ b/src/Rotation.cpp
@@ -0,0 +1,271 @@
+#include "Rotation.h"
+
+#include <algorithm>
+#include <exception>
+
+#include <Eigen/Geometry>
+
+namespace ale {
+
+///////////////////////////////////////////////////////////////////////////////
+// Helper Functions
+///////////////////////////////////////////////////////////////////////////////
+
+  // Linearly interpolate between two values
+  double linearInterpolate(double x, double y, double t) {
+    return x + t * (y - x);
+  }
+
+
+  // Helper function to convert an axis number into a unit Eigen vector down that axis.
+  Eigen::Vector3d axis(int axisIndex) {
+    switch (axisIndex) {
+      case 0:
+        return Eigen::Vector3d::UnitX();
+        break;
+      case 1:
+        return Eigen::Vector3d::UnitY();
+        break;
+      case 2:
+        return Eigen::Vector3d::UnitZ();
+        break;
+      default:
+        throw std::invalid_argument("Axis index must be 0, 1, or 2.");
+    }
+  }
+
+
+  /**
+   * Create the skew symmetric matrix used when computing the derivative of a
+   * rotation matrix.
+   *
+   * This is actually the transpose of the skew AV matrix because we define AV
+   * as the AV from the destination to the source. This matches how NAIF
+   * defines AV.
+   */
+  Eigen::Quaterniond::Matrix3 avSkewMatrix(
+        const std::vector<double>& av
+  ) {
+    if (av.size() != 3) {
+      throw std::invalid_argument("Angular velocity vector to rotate is the wrong size.");
+    }
+    Eigen::Quaterniond::Matrix3 avMat;
+    avMat <<  0.0,    av[2], -av[1],
+             -av[2],  0.0,    av[0],
+              av[1], -av[0],  0.0;
+    return avMat;
+  }
+
+  ///////////////////////////////////////////////////////////////////////////////
+  // Rotation Impl class
+  ///////////////////////////////////////////////////////////////////////////////
+
+  // Internal representation of the rotation as an Eigen Double Quaternion
+  class Rotation::Impl {
+    public:
+      Impl() : quat(Eigen::Quaterniond::Identity()) { }
+
+
+      Impl(double w, double x, double y, double z) : quat(w, x, y, z) { }
+
+
+      Impl(const std::vector<double>& matrix) {
+        if (matrix.size() != 9) {
+          throw std::invalid_argument("Rotation matrix must be 3 by 3.");
+        }
+        quat = Eigen::Quaterniond(Eigen::Quaterniond::Matrix3(matrix.data()));
+      }
+
+
+      Impl(const std::vector<double>& angles, const std::vector<int>& axes) {
+        if (angles.empty() || axes.empty()) {
+          throw std::invalid_argument("Angles and axes must be non-empty.");
+        }
+        if (angles.size() != axes.size()) {
+          throw std::invalid_argument("Number of angles and axes must be equal.");
+        }
+        quat = Eigen::Quaterniond::Identity();
+
+        for (size_t i = 0; i < angles.size(); i++) {
+          quat *= Eigen::Quaterniond(Eigen::AngleAxisd(angles[i], axis(axes[i])));
+        }
+      }
+
+
+      Impl(const std::vector<double>& axis, double theta) {
+        if (axis.size() != 3) {
+          throw std::invalid_argument("Rotation axis must have 3 elements.");
+        }
+        Eigen::Vector3d eigenAxis((double *) axis.data());
+        quat = Eigen::Quaterniond(Eigen::AngleAxisd(theta, eigenAxis.normalized()));
+      }
+
+
+      Eigen::Quaterniond quat;
+  };
+
+  ///////////////////////////////////////////////////////////////////////////////
+  // Rotation Class
+  ///////////////////////////////////////////////////////////////////////////////
+
+  Rotation::Rotation() :
+        m_impl(new Impl()) { }
+
+
+  Rotation::Rotation(double w, double x, double y, double z) :
+        m_impl(new Impl(w, x, y, z)) { }
+
+
+  Rotation::Rotation(const std::vector<double>& matrix) :
+        m_impl(new Impl(matrix)) { }
+
+
+  Rotation::Rotation(const std::vector<double>& angles, const std::vector<int>& axes) :
+        m_impl(new Impl(angles, axes)) { }
+
+
+  Rotation::Rotation(const std::vector<double>& axis, double theta) :
+        m_impl(new Impl(axis, theta)) { }
+
+
+  Rotation::~Rotation() = default;
+
+
+  Rotation::Rotation(Rotation && other) noexcept = default;
+
+
+  Rotation& Rotation::operator=(Rotation && other) noexcept = default;
+
+
+  // unique_ptr doesn't have a copy constructor so we have to define one
+  Rotation::Rotation(const Rotation& other) : m_impl(new Impl(*other.m_impl)) { }
+
+
+  // unique_ptr doesn't have an assignment operator so we have to define one
+  Rotation& Rotation::operator=(const Rotation& other) {
+    if (this != &other) {
+      m_impl.reset(new Impl(*other.m_impl));
+    }
+    return *this;
+  }
+
+
+  std::vector<double> Rotation::toQuaternion() const {
+    Eigen::Quaterniond normalized = m_impl->quat.normalized();
+    return {normalized.w(), normalized.x(), normalized.y(), normalized.z()};
+  }
+
+
+  std::vector<double> Rotation::toRotationMatrix() const {
+    Eigen::Quaterniond::RotationMatrixType mat = m_impl->quat.toRotationMatrix();
+    return std::vector<double>(mat.data(), mat.data() + mat.size());
+  }
+
+
+  std::vector<double> Rotation::toStateRotationMatrix(const std::vector<double> &av) const {
+    Eigen::Quaterniond::Matrix3 rotMat = m_impl->quat.toRotationMatrix();
+    Eigen::Quaterniond::Matrix3 avMat = avSkewMatrix(av);
+    Eigen::Quaterniond::Matrix3 dtMat = rotMat * avMat;
+    return {rotMat(0,0), rotMat(0,1), rotMat(0,2), 0.0,         0.0,         0.0,
+            rotMat(1,0), rotMat(1,1), rotMat(1,2), 0.0,         0.0,         0.0,
+            rotMat(2,0), rotMat(2,1), rotMat(2,2), 0.0,         0.0,         0.0,
+            dtMat(0,0),  dtMat(0,1),  dtMat(0,2),  rotMat(0,0), rotMat(0,1), rotMat(0,2),
+            dtMat(1,0),  dtMat(1,1),  dtMat(1,2),  rotMat(1,0), rotMat(1,1), rotMat(1,2),
+            dtMat(2,0),  dtMat(2,1),  dtMat(2,2),  rotMat(2,0), rotMat(2,1), rotMat(2,2)};
+  }
+
+
+  std::vector<double> Rotation::toEuler(const std::vector<int>& axes) const {
+    if (axes.size() != 3) {
+      throw std::invalid_argument("Must have 3 axes to convert to Euler angles.");
+    }
+    if (axes[0] < 0 || axes[0] > 2 ||
+        axes[1] < 0 || axes[1] > 2 ||
+        axes[2] < 0 || axes[2] > 2) {
+      throw std::invalid_argument("Invalid axis number.");
+    }
+    Eigen::Vector3d angles = m_impl->quat.toRotationMatrix().eulerAngles(
+          axes[0],
+          axes[1],
+          axes[2]);
+    return std::vector<double>(angles.data(), angles.data() + angles.size());
+  }
+
+
+  std::pair<std::vector<double>, double> Rotation::toAxisAngle() const {
+    Eigen::AngleAxisd eigenAxisAngle(m_impl->quat);
+    std::pair<std::vector<double>, double> axisAngle;
+    axisAngle.first = std::vector<double>(
+          eigenAxisAngle.axis().data(),
+          eigenAxisAngle.axis().data() + eigenAxisAngle.axis().size()
+    );
+    axisAngle.second = eigenAxisAngle.angle();
+    return axisAngle;
+  }
+
+
+  std::vector<double> Rotation::operator()(
+        const std::vector<double>& vector,
+        const std::vector<double>& av
+  ) const {
+    if (vector.size() == 3) {
+      Eigen::Map<Eigen::Vector3d> eigenVector((double *)vector.data());
+      Eigen::Vector3d rotatedVector = m_impl->quat._transformVector(eigenVector);
+      return std::vector<double>(rotatedVector.data(), rotatedVector.data() + rotatedVector.size());
+    }
+    else if (vector.size() == 6) {
+      Eigen::Map<Eigen::Vector3d> positionVector((double *)vector.data());
+      Eigen::Map<Eigen::Vector3d> velocityVector((double *)vector.data() + 3);
+      Eigen::Quaterniond::Matrix3 rotMat = m_impl->quat.toRotationMatrix();
+      Eigen::Quaterniond::Matrix3 avMat = avSkewMatrix(av);
+      Eigen::Quaterniond::Matrix3 rotationDerivative = rotMat * avMat;
+      Eigen::Vector3d rotatedPosition = rotMat * positionVector;
+      Eigen::Vector3d rotatedVelocity = rotMat * velocityVector + rotationDerivative * positionVector;
+      return {rotatedPosition(0), rotatedPosition(1), rotatedPosition(2),
+              rotatedVelocity(0), rotatedVelocity(1), rotatedVelocity(2)};
+    }
+    else {
+      throw std::invalid_argument("Vector to rotate is the wrong size.");
+    }
+  }
+
+
+  Rotation Rotation::inverse() const {
+    Eigen::Quaterniond inverseQuat = m_impl->quat.inverse();
+    return Rotation(inverseQuat.w(), inverseQuat.x(), inverseQuat.y(), inverseQuat.z());
+  }
+
+
+  Rotation Rotation::operator*(const Rotation& rightRotation) const {
+    Eigen::Quaterniond combinedQuat = m_impl->quat * rightRotation.m_impl->quat;
+    return Rotation(combinedQuat.w(), combinedQuat.x(), combinedQuat.y(), combinedQuat.z());
+  }
+
+
+  Rotation Rotation::interpolate(
+        const Rotation& nextRotation,
+        double t,
+        RotationInterpolation interpType
+  ) const {
+    Eigen::Quaterniond interpQuat;
+    switch (interpType) {
+      case slerp:
+        interpQuat = m_impl->quat.slerp(t, nextRotation.m_impl->quat);
+        break;
+      case nlerp:
+        interpQuat = Eigen::Quaterniond(
+              linearInterpolate(m_impl->quat.w(), nextRotation.m_impl->quat.w(), t),
+              linearInterpolate(m_impl->quat.x(), nextRotation.m_impl->quat.x(), t),
+              linearInterpolate(m_impl->quat.y(), nextRotation.m_impl->quat.y(), t),
+              linearInterpolate(m_impl->quat.z(), nextRotation.m_impl->quat.z(), t)
+        );
+        interpQuat.normalize();
+        break;
+      default:
+        throw std::invalid_argument("Unsupported rotation interpolation type.");
+        break;
+    }
+    return Rotation(interpQuat.w(), interpQuat.x(), interpQuat.y(), interpQuat.z());
+  }
+
+}
diff --git a/tests/ctests/RotationTest.cpp b/tests/ctests/RotationTest.cpp
new file mode 100644
index 0000000..e8b8390
--- /dev/null
+++ b/tests/ctests/RotationTest.cpp
@@ -0,0 +1,438 @@
+#include "gtest/gtest.h"
+
+#include "Rotation.h"
+
+#include <cmath>
+#include <exception>
+
+using namespace std;
+using namespace ale;
+
+TEST(RotationTest, DefaultConstructor) {
+  Rotation defaultRotation;
+  vector<double> defaultQuat = defaultRotation.toQuaternion();
+  ASSERT_EQ(defaultQuat.size(), 4);
+  EXPECT_NEAR(defaultQuat[0], 1.0, 1e-10);
+  EXPECT_NEAR(defaultQuat[1], 0.0, 1e-10);
+  EXPECT_NEAR(defaultQuat[2], 0.0, 1e-10);
+  EXPECT_NEAR(defaultQuat[3], 0.0, 1e-10);
+}
+
+TEST(RotationTest, QuaternionConstructor) {
+  Rotation rotation(1.0/sqrt(2), 1.0/sqrt(2), 0.0, 0.0);
+  vector<double> quat = rotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], 1.0/sqrt(2), 1e-10);
+  EXPECT_NEAR(quat[1], 1.0/sqrt(2), 1e-10);
+  EXPECT_NEAR(quat[2], 0.0, 1e-10);
+  EXPECT_NEAR(quat[3], 0.0, 1e-10);
+}
+
+TEST(RotationTest, MatrixConstructor) {
+  Rotation rotation(
+        {0.0, 0.0, 1.0,
+         1.0, 0.0, 0.0,
+         0.0, 1.0, 0.0});
+  vector<double> quat = rotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], -0.5, 1e-10);
+  EXPECT_NEAR(quat[1],  0.5, 1e-10);
+  EXPECT_NEAR(quat[2],  0.5, 1e-10);
+  EXPECT_NEAR(quat[3],  0.5, 1e-10);
+}
+
+TEST(RotationTest, BadMatrixConstructor) {
+  ASSERT_THROW(Rotation({0.0, 0.0, 1.0,
+                         1.0, 0.0, 0.0,
+                         0.0, 1.0, 0.0,
+                         1.0, 0.0, 2.0}),
+               std::invalid_argument);
+}
+
+TEST(RotationTest, SingleAngleConstructor) {
+  std::vector<double> angles;
+  angles.push_back(M_PI);
+  std::vector<int> axes;
+  axes.push_back(0);
+  Rotation rotation(angles, axes);
+  vector<double> quat = rotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], 0.0, 1e-10);
+  EXPECT_NEAR(quat[1], 1.0, 1e-10);
+  EXPECT_NEAR(quat[2], 0.0, 1e-10);
+  EXPECT_NEAR(quat[3], 0.0, 1e-10);
+}
+
+TEST(RotationTest, MultiAngleConstructor) {
+  Rotation rotation({M_PI/2, -M_PI/2, M_PI}, {0, 1, 2});
+  vector<double> quat = rotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0],  0.5, 1e-10);
+  EXPECT_NEAR(quat[1], -0.5, 1e-10);
+  EXPECT_NEAR(quat[2], -0.5, 1e-10);
+  EXPECT_NEAR(quat[3],  0.5, 1e-10);
+}
+
+TEST(RotationTest, DifferentAxisAngleCount) {
+  std::vector<double> angles;
+  angles.push_back(M_PI);
+  std::vector<int> axes = {0, 1, 2};
+  ASSERT_THROW(Rotation(angles, axes), std::invalid_argument);
+}
+
+TEST(RotationTest, EmptyAxisAngle) {
+  std::vector<double> angles;
+  std::vector<int> axes = {0, 1, 2};
+  ASSERT_THROW(Rotation(angles, axes), std::invalid_argument);
+}
+
+TEST(RotationTest, BadAxisNumber) {
+  std::vector<double> angles;
+  angles.push_back(M_PI);
+  std::vector<int> axes;
+  axes.push_back(4);
+  ASSERT_THROW(Rotation(angles, axes), std::invalid_argument);
+}
+
+TEST(RotationTest, AxisAngleConstructor) {
+  Rotation rotation({1.0, 1.0, 1.0}, 2.0 / 3.0 * M_PI);
+  vector<double> quat = rotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], 0.5, 1e-10);
+  EXPECT_NEAR(quat[1], 0.5, 1e-10);
+  EXPECT_NEAR(quat[2], 0.5, 1e-10);
+  EXPECT_NEAR(quat[3], 0.5, 1e-10);
+}
+
+TEST(RotationTest, BadAxisAngleConstructor) {
+  ASSERT_THROW(Rotation({1.0, 1.0, 1.0, 1.0}, 2.0 / 3.0 * M_PI), std::invalid_argument);
+}
+
+TEST(RotationTest, ToRotationMatrix) {
+  Rotation rotation(-0.5, 0.5, 0.5, 0.5);
+  vector<double> mat = rotation.toRotationMatrix();
+  ASSERT_EQ(mat.size(), 9);
+  EXPECT_NEAR(mat[0], 0.0, 1e-10);
+  EXPECT_NEAR(mat[1], 0.0, 1e-10);
+  EXPECT_NEAR(mat[2], 1.0, 1e-10);
+  EXPECT_NEAR(mat[3], 1.0, 1e-10);
+  EXPECT_NEAR(mat[4], 0.0, 1e-10);
+  EXPECT_NEAR(mat[5], 0.0, 1e-10);
+  EXPECT_NEAR(mat[6], 0.0, 1e-10);
+  EXPECT_NEAR(mat[7], 1.0, 1e-10);
+  EXPECT_NEAR(mat[8], 0.0, 1e-10);
+}
+
+TEST(RotationTest, ToStateRotationMatrix) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  std::vector<double> av = {2.0 / 3.0 * M_PI, 2.0 / 3.0 * M_PI, 2.0 / 3.0 * M_PI};
+  vector<double> mat = rotation.toStateRotationMatrix(av);
+  ASSERT_EQ(mat.size(), 36);
+  EXPECT_NEAR(mat[0], 0.0, 1e-10);
+  EXPECT_NEAR(mat[1], 0.0, 1e-10);
+  EXPECT_NEAR(mat[2], 1.0, 1e-10);
+  EXPECT_NEAR(mat[3], 0.0, 1e-10);
+  EXPECT_NEAR(mat[4], 0.0, 1e-10);
+  EXPECT_NEAR(mat[5], 0.0, 1e-10);
+
+  EXPECT_NEAR(mat[6], 1.0, 1e-10);
+  EXPECT_NEAR(mat[7], 0.0, 1e-10);
+  EXPECT_NEAR(mat[8], 0.0, 1e-10);
+  EXPECT_NEAR(mat[9], 0.0, 1e-10);
+  EXPECT_NEAR(mat[10], 0.0, 1e-10);
+  EXPECT_NEAR(mat[11], 0.0, 1e-10);
+
+  EXPECT_NEAR(mat[12], 0.0, 1e-10);
+  EXPECT_NEAR(mat[13], 1.0, 1e-10);
+  EXPECT_NEAR(mat[14], 0.0, 1e-10);
+  EXPECT_NEAR(mat[15], 0.0, 1e-10);
+  EXPECT_NEAR(mat[16], 0.0, 1e-10);
+  EXPECT_NEAR(mat[17], 0.0, 1e-10);
+
+  EXPECT_NEAR(mat[18], 2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(mat[19], -2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(mat[20], 0.0, 1e-10);
+  EXPECT_NEAR(mat[21], 0.0, 1e-10);
+  EXPECT_NEAR(mat[22], 0.0, 1e-10);
+  EXPECT_NEAR(mat[23], 1.0, 1e-10);
+
+  EXPECT_NEAR(mat[24], 0.0, 1e-10);
+  EXPECT_NEAR(mat[25], 2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(mat[26], -2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(mat[27], 1.0, 1e-10);
+  EXPECT_NEAR(mat[28], 0.0, 1e-10);
+  EXPECT_NEAR(mat[29], 0.0, 1e-10);
+
+  EXPECT_NEAR(mat[30], -2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(mat[31], 0.0, 1e-10);
+  EXPECT_NEAR(mat[32], 2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(mat[33], 0.0, 1e-10);
+  EXPECT_NEAR(mat[34], 1.0, 1e-10);
+  EXPECT_NEAR(mat[35], 0.0, 1e-10);
+}
+
+TEST(RotationTest, BadAvVectorSize) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  std::vector<double> av = {2.0 / 3.0 * M_PI, 2.0 / 3.0 * M_PI};
+  ASSERT_THROW(rotation.toStateRotationMatrix(av), std::invalid_argument);
+}
+
+TEST(RotationTest, ToEulerXYZ) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  vector<double> angles = rotation.toEuler({0, 1, 2});
+  ASSERT_EQ(angles.size(), 3);
+  EXPECT_NEAR(angles[0], 0.0, 1e-10);
+  EXPECT_NEAR(angles[1], M_PI/2, 1e-10);
+  EXPECT_NEAR(angles[2], M_PI/2, 1e-10);
+}
+
+TEST(RotationTest, ToEulerZYX) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  vector<double> angles = rotation.toEuler({2, 1, 0});
+  ASSERT_EQ(angles.size(), 3);
+  EXPECT_NEAR(angles[0], M_PI/2, 1e-10);
+  EXPECT_NEAR(angles[1], 0.0, 1e-10);
+  EXPECT_NEAR(angles[2], M_PI/2, 1e-10);
+}
+
+TEST(RotationTest, ToEulerWrongNumberOfAxes) {
+  Rotation rotation;
+  ASSERT_ANY_THROW(rotation.toEuler({1, 0}));
+}
+
+TEST(RotationTest, ToEulerBadAxisNumber) {
+  Rotation rotation;
+  ASSERT_THROW(rotation.toEuler({4, 1, 0}), std::invalid_argument);
+}
+
+TEST(RotationTest, ToAxisAngle) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  std::pair<std::vector<double>, double> axisAngle = rotation.toAxisAngle();
+  ASSERT_EQ(axisAngle.first.size(), 3);
+  EXPECT_NEAR(axisAngle.first[0], 1.0 / sqrt(3), 1e-10);
+  EXPECT_NEAR(axisAngle.first[1], 1.0 / sqrt(3), 1e-10);
+  EXPECT_NEAR(axisAngle.first[2], 1.0 / sqrt(3), 1e-10);
+  EXPECT_NEAR(axisAngle.second, 2.0 / 3.0 * M_PI, 1e-10);
+}
+
+TEST(RotationTest, RotateVector) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  vector<double> unitX = {1.0, 0.0, 0.0};
+  vector<double> unitY = {0.0, 1.0, 0.0};
+  vector<double> unitZ = {0.0, 0.0, 1.0};
+  vector<double> rotatedX = rotation(unitX);
+  vector<double> rotatedY = rotation(unitY);
+  vector<double> rotatedZ = rotation(unitZ);
+  ASSERT_EQ(rotatedX.size(), 3);
+  EXPECT_NEAR(rotatedX[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[1], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedX[2], 0.0, 1e-10);
+  ASSERT_EQ(rotatedY.size(), 3);
+  EXPECT_NEAR(rotatedY[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[2], 1.0, 1e-10);
+  ASSERT_EQ(rotatedZ.size(), 3);
+  EXPECT_NEAR(rotatedZ[0], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[2], 0.0, 1e-10);
+}
+
+TEST(RotationTest, RotateState) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  std::vector<double> av = {2.0 / 3.0 * M_PI, 2.0 / 3.0 * M_PI, 2.0 / 3.0 * M_PI};
+  vector<double> unitX =  {1.0, 0.0, 0.0, 0.0, 0.0, 0.0};
+  vector<double> unitY =  {0.0, 1.0, 0.0, 0.0, 0.0, 0.0};
+  vector<double> unitZ =  {0.0, 0.0, 1.0, 0.0, 0.0, 0.0};
+  vector<double> unitVX = {0.0, 0.0, 0.0, 1.0, 0.0, 0.0};
+  vector<double> unitVY = {0.0, 0.0, 0.0, 0.0, 1.0, 0.0};
+  vector<double> unitVZ = {0.0, 0.0, 0.0, 0.0, 0.0, 1.0};
+  vector<double> rotatedX = rotation(unitX, av);
+  vector<double> rotatedY = rotation(unitY, av);
+  vector<double> rotatedZ = rotation(unitZ, av);
+  vector<double> rotatedVX = rotation(unitVX, av);
+  vector<double> rotatedVY = rotation(unitVY, av);
+  vector<double> rotatedVZ = rotation(unitVZ, av);
+  ASSERT_EQ(rotatedX.size(), 6);
+  EXPECT_NEAR(rotatedX[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[1], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedX[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[3], 2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(rotatedX[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[5], -2.0 / 3.0 * M_PI, 1e-10);
+  ASSERT_EQ(rotatedY.size(), 6);
+  EXPECT_NEAR(rotatedY[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[2], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedY[3], -2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(rotatedY[4], 2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(rotatedY[5], 0.0, 1e-10);
+  ASSERT_EQ(rotatedZ.size(), 6);
+  EXPECT_NEAR(rotatedZ[0], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[4], -2.0 / 3.0 * M_PI, 1e-10);
+  EXPECT_NEAR(rotatedZ[5], 2.0 / 3.0 * M_PI, 1e-10);
+  ASSERT_EQ(rotatedVX.size(), 6);
+  EXPECT_NEAR(rotatedVX[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[4], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[5], 0.0, 1e-10);
+  ASSERT_EQ(rotatedVY.size(), 6);
+  EXPECT_NEAR(rotatedVY[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[5], 1.0, 1e-10);
+  ASSERT_EQ(rotatedVZ.size(), 6);
+  EXPECT_NEAR(rotatedVZ[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[3], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[5], 0.0, 1e-10);
+}
+
+TEST(RotationTest, RotateStateNoAv) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  vector<double> unitX =  {1.0, 0.0, 0.0, 0.0, 0.0, 0.0};
+  vector<double> unitY =  {0.0, 1.0, 0.0, 0.0, 0.0, 0.0};
+  vector<double> unitZ =  {0.0, 0.0, 1.0, 0.0, 0.0, 0.0};
+  vector<double> unitVX = {0.0, 0.0, 0.0, 1.0, 0.0, 0.0};
+  vector<double> unitVY = {0.0, 0.0, 0.0, 0.0, 1.0, 0.0};
+  vector<double> unitVZ = {0.0, 0.0, 0.0, 0.0, 0.0, 1.0};
+  vector<double> rotatedX = rotation(unitX);
+  vector<double> rotatedY = rotation(unitY);
+  vector<double> rotatedZ = rotation(unitZ);
+  vector<double> rotatedVX = rotation(unitVX);
+  vector<double> rotatedVY = rotation(unitVY);
+  vector<double> rotatedVZ = rotation(unitVZ);
+  ASSERT_EQ(rotatedX.size(), 6);
+  EXPECT_NEAR(rotatedX[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[1], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedX[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedX[5], 0.0, 1e-10);
+  ASSERT_EQ(rotatedY.size(), 6);
+  EXPECT_NEAR(rotatedY[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[2], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedY[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedY[5], 0.0, 1e-10);
+  ASSERT_EQ(rotatedZ.size(), 6);
+  EXPECT_NEAR(rotatedZ[0], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedZ[5], 0.0, 1e-10);
+  ASSERT_EQ(rotatedVX.size(), 6);
+  EXPECT_NEAR(rotatedVX[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[4], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedVX[5], 0.0, 1e-10);
+  ASSERT_EQ(rotatedVY.size(), 6);
+  EXPECT_NEAR(rotatedVY[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[3], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVY[5], 1.0, 1e-10);
+  ASSERT_EQ(rotatedVZ.size(), 6);
+  EXPECT_NEAR(rotatedVZ[0], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[1], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[2], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[3], 1.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[4], 0.0, 1e-10);
+  EXPECT_NEAR(rotatedVZ[5], 0.0, 1e-10);
+}
+
+TEST(RotationTest, RotateWrongSizeAV) {
+  Rotation rotation;
+  ASSERT_NO_THROW(rotation({1.0, 1.0, 1.0}));
+  ASSERT_NO_THROW(rotation({1.0, 1.0, 1.0}, {1.0, 1.0}));
+}
+
+TEST(RotationTest, RotateWrongSizeVector) {
+  Rotation rotation;
+  ASSERT_THROW(rotation({1.0, 1.0, 1.0, 1.0}), std::invalid_argument);
+}
+
+TEST(RotationTest, Inverse) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  Rotation inverseRotation = rotation.inverse();
+  vector<double> quat = inverseRotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0],  0.5, 1e-10);
+  EXPECT_NEAR(quat[1], -0.5, 1e-10);
+  EXPECT_NEAR(quat[2], -0.5, 1e-10);
+  EXPECT_NEAR(quat[3], -0.5, 1e-10);
+}
+
+TEST(RotationTest, MultiplyRotation) {
+  Rotation rotation(0.5, 0.5, 0.5, 0.5);
+  Rotation doubleRotation = rotation * rotation;
+  vector<double> quat = doubleRotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], -0.5, 1e-10);
+  EXPECT_NEAR(quat[1],  0.5, 1e-10);
+  EXPECT_NEAR(quat[2],  0.5, 1e-10);
+  EXPECT_NEAR(quat[3],  0.5, 1e-10);
+}
+
+TEST(RotationTest, Slerp) {
+  Rotation rotationOne(0.5, 0.5, 0.5, 0.5);
+  Rotation rotationTwo(-0.5, 0.5, 0.5, 0.5);
+  Rotation interpRotation = rotationOne.interpolate(rotationTwo, 0.125, ale::slerp);
+  vector<double> quat = interpRotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], cos(M_PI * 3.0/8.0), 1e-10);
+  EXPECT_NEAR(quat[1], sin(M_PI * 3.0/8.0) * 1/sqrt(3.0), 1e-10);
+  EXPECT_NEAR(quat[2], sin(M_PI * 3.0/8.0) * 1/sqrt(3.0), 1e-10);
+  EXPECT_NEAR(quat[3], sin(M_PI * 3.0/8.0) * 1/sqrt(3.0), 1e-10);
+}
+
+TEST(RotationTest, SlerpExtrapolate) {
+  Rotation rotationOne(0.5, 0.5, 0.5, 0.5);
+  Rotation rotationTwo(-0.5, 0.5, 0.5, 0.5);
+  Rotation interpRotation = rotationOne.interpolate(rotationTwo, 1.125, ale::slerp);
+  vector<double> quat = interpRotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], cos(M_PI * 17.0/24.0), 1e-10);
+  EXPECT_NEAR(quat[1], sin(M_PI * 17.0/24.0) * 1/sqrt(3.0), 1e-10);
+  EXPECT_NEAR(quat[2], sin(M_PI * 17.0/24.0) * 1/sqrt(3.0), 1e-10);
+  EXPECT_NEAR(quat[3], sin(M_PI * 17.0/24.0) * 1/sqrt(3.0), 1e-10);
+}
+
+TEST(RotationTest, Nlerp) {
+  Rotation rotationOne(0.5, 0.5, 0.5, 0.5);
+  Rotation rotationTwo(-0.5, 0.5, 0.5, 0.5);
+  Rotation interpRotation = rotationOne.interpolate(rotationTwo, 0.125, ale::nlerp);
+  double scaling = 8.0 / sqrt(57.0);
+  vector<double> quat = interpRotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], 3.0 / 8.0 * scaling, 1e-10);
+  EXPECT_NEAR(quat[1], 1.0 / 2.0 * scaling, 1e-10);
+  EXPECT_NEAR(quat[2], 1.0 / 2.0 * scaling, 1e-10);
+  EXPECT_NEAR(quat[3], 1.0 / 2.0 * scaling, 1e-10);
+}
+
+TEST(RotationTest, NlerpExtrapolate) {
+  Rotation rotationOne(0.5, 0.5, 0.5, 0.5);
+  Rotation rotationTwo(-0.5, 0.5, 0.5, 0.5);
+  Rotation interpRotation = rotationOne.interpolate(rotationTwo, 1.125, ale::nlerp);
+  double scaling = 8.0 / sqrt(73.0);
+  vector<double> quat = interpRotation.toQuaternion();
+  ASSERT_EQ(quat.size(), 4);
+  EXPECT_NEAR(quat[0], -5.0 / 8.0 * scaling, 1e-10);
+  EXPECT_NEAR(quat[1], 1.0 / 2.0 * scaling, 1e-10);
+  EXPECT_NEAR(quat[2], 1.0 / 2.0 * scaling, 1e-10);
+  EXPECT_NEAR(quat[3], 1.0 / 2.0 * scaling, 1e-10);
+}
-- 
GitLab