From 03e2fc959702e7654c507c4c6125ea1c1b54ecb6 Mon Sep 17 00:00:00 2001 From: Benjamin Culkin Date: Mon, 5 Feb 2018 13:01:44 -0800 Subject: Add dual numbers Dual numbers are a easy way of doing automatic numeric differentiation of expressions. --- base/src/main/java/bjc/utils/math/Dual.java | 83 ++++++ base/src/main/java/bjc/utils/math/DualExpr.java | 328 ++++++++++++++++++++++++ 2 files changed, 411 insertions(+) create mode 100644 base/src/main/java/bjc/utils/math/Dual.java create mode 100644 base/src/main/java/bjc/utils/math/DualExpr.java (limited to 'base/src') diff --git a/base/src/main/java/bjc/utils/math/Dual.java b/base/src/main/java/bjc/utils/math/Dual.java new file mode 100644 index 0000000..53ddc32 --- /dev/null +++ b/base/src/main/java/bjc/utils/math/Dual.java @@ -0,0 +1,83 @@ +package bjc.utils.math; + +/** + * Represents a 'dual' number. + * + * Think imaginary numbers, where instead of i, we add a value d such that d^2 = + * 0. + */ +public class Dual { + /** + * The real part of the dual number. + */ + public double real; + /** + * The dual part of the dual number. + */ + public double dual; + + /** + * Create a new dual with both parts zero. + */ + public Dual() { + real = 0; + dual = 0; + } + + /** + * Create a new dual number with a zero dual part. + * + * @param real + * The real part of the number. + */ + public Dual(double real) { + this.real = real; + this.dual = 0; + } + + /** + * Create a new dual number with a specified dual part. + * + * @param real + * The real part of the number. + * @param dual + * The dual part of the number. + */ + public Dual(double real, double dual) { + this.real = real; + this.dual = dual; + } + + @Override + public String toString() { + return String.format("<%f, %f>", real, dual); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + long temp; + temp = Double.doubleToLongBits(dual); + result = prime * result + (int) (temp ^ (temp >>> 32)); + temp = Double.doubleToLongBits(real); + result = prime * result + (int) (temp ^ (temp >>> 32)); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Dual other = (Dual) obj; + if (Double.doubleToLongBits(dual) != Double.doubleToLongBits(other.dual)) + return false; + if (Double.doubleToLongBits(real) != Double.doubleToLongBits(other.real)) + return false; + return true; + } +} \ No newline at end of file diff --git a/base/src/main/java/bjc/utils/math/DualExpr.java b/base/src/main/java/bjc/utils/math/DualExpr.java new file mode 100644 index 0000000..d0e9acf --- /dev/null +++ b/base/src/main/java/bjc/utils/math/DualExpr.java @@ -0,0 +1,328 @@ +package bjc.utils.math; +/** + * Represents an expression using dual numbers. + * + * Useful for automatically differentiating expressions. + */ +public class DualExpr { + /** + * Represents the various types of dual expressions. + */ + public static enum ExprType { + /** + * A fixed number. + */ + CONSTANT, + /** + * An addition operation. + */ + ADDITION, + /** + * A subtraction operation. + */ + SUBTRACTION, + /** + * A multiplication operation. + */ + MULTIPLICATION, + /** + * A division operation. + */ + DIVISION, + /** + * A sine operation. + */ + SIN, + /** + * A cosine operation. + */ + COS, + /** + * An exponential function. + */ + EXPONENTIAL, + /** + * A logarithm function. + */ + LOGARITHM, + /** + * A power operation. + */ + POWER, + /** + * An absolute value. + */ + ABSOLUTE + } + + /** + * The type of the expression. + */ + public final DualExpr.ExprType type; + + /** + * The dual number value, for constants. + */ + public Dual number; + + /** + * The left (or first) part of the expression. + */ + public DualExpr left; + /** + * The right (or second) part of the expression. + */ + public DualExpr right; + + /** + * The power to use, for power operations. + */ + public int power; + + /** + * Create a new constant dual number. + * + * @param num + * The value of the dual number. + */ + public DualExpr(Dual num) { + this.type = ExprType.CONSTANT; + + number = num; + } + + /** + * Create a new unary dual number. + * + * @param type + * The type of operation to perform. + * @param val + * The parameter to the value. + */ + public DualExpr(DualExpr.ExprType type, DualExpr val) { + this.type = type; + + left = val; + } + + /** + * Create a new binary dual number. + * + * @param type + * The type of operation to perform. + * @param val + * The parameter to the value. + */ + public DualExpr(DualExpr.ExprType type, DualExpr left, DualExpr right) { + this.type = type; + + this.left = left; + this.right = right; + } + + /** + * Create a new power expression. + * + * @param left + * The expression to raise. + * @param power + * The power to raise it by. + */ + public DualExpr(DualExpr left, int power) { + this.type = ExprType.POWER; + + this.left = left; + this.power = power; + } + + /** + * Evaluate an expression to a number. + * + * Uses the rules provided in + * https://en.wikipedia.org/wiki/Automatic_differentiation + * + * @return The evaluated expression. + */ + public Dual evaluate() { + /* The evaluated dual numbers. */ + Dual lval, rval; + + /* Perform the right operation for each type. */ + switch (type) { + case CONSTANT: + return number; + case ADDITION: + lval = left.evaluate(); + rval = right.evaluate(); + + return new Dual(lval.real + rval.real, lval.dual + rval.dual); + case SUBTRACTION: + lval = left.evaluate(); + rval = right.evaluate(); + + return new Dual(lval.real - rval.real, lval.dual - rval.dual); + case MULTIPLICATION: + lval = left.evaluate(); + rval = right.evaluate(); + + { + double lft = lval.dual * rval.real; + double rght = lval.real * rval.dual; + + return new Dual(lval.real * rval.real, lft + rght); + } + case DIVISION: + lval = left.evaluate(); + rval = right.evaluate(); + + { + if (rval.real == 0) { + throw new IllegalArgumentException("ERROR: Attempted to divide by zero."); + } + + double lft = lval.dual * rval.real; + double rght = lval.real * rval.dual; + + double val = (lft - rght) / (rval.real * rval.real); + + return new Dual(lval.real / rval.real, val); + } + case SIN: + lval = left.evaluate(); + + return new Dual(Math.sin(lval.real), lval.dual * Math.cos(lval.real)); + case COS: + lval = left.evaluate(); + + return new Dual(Math.cos(lval.real), -lval.dual * Math.sin(lval.real)); + case EXPONENTIAL: + lval = left.evaluate(); + + { + double val = Math.exp(lval.real); + + return new Dual(val, lval.dual * val); + } + case LOGARITHM: + lval = left.evaluate(); + + if (lval.real <= 0) { + throw new IllegalArgumentException( + "ERROR: Attempted to take non-positive log."); + } + + return new Dual(Math.log(lval.real), lval.dual / lval.real); + case POWER: + lval = left.evaluate(); + + if (lval.real == 0) { + throw new IllegalArgumentException("ERROR: Raising zero to a power."); + } + + { + double rl = Math.pow(lval.real, power); + + double lft = Math.pow(lval.real, power - 1); + + return new Dual(rl, power * lft * lval.dual); + } + case ABSOLUTE: + lval = left.evaluate(); + + return new Dual(Math.abs(lval.real), lval.dual * Math.signum(lval.real)); + default: + String msg = "ERROR: Unknown expression type %s"; + + throw new IllegalArgumentException(String.format(msg, type)); + } + } + + @Override + public String toString() { + switch (type) { + case ABSOLUTE: + return String.format("abs(%s)", left.toString()); + case ADDITION: + return String.format("(%s + %s)", left.toString(), right.toString()); + case CONSTANT: + return String.format("%s", number.toString()); + case COS: + return String.format("cos(%s)", left.toString()); + case DIVISION: + return String.format("(%s / %s)", left.toString(), right.toString()); + case EXPONENTIAL: + return String.format("exp(%s)", left.toString()); + case LOGARITHM: + return String.format("log(%s)", left.toString()); + case MULTIPLICATION: + return String.format("(%s * %s)", left.toString(), right.toString()); + case POWER: + return String.format("(%s ^ %d)", left.toString(), power); + case SIN: + return String.format("sin(%s)", left.toString()); + case SUBTRACTION: + return String.format("(%s - %s)", left.toString(), right.toString()); + default: + return String.format("UNKNOWN_EXPR"); + } + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + + result = prime * result + ((left == null) ? 0 : left.hashCode()); + result = prime * result + ((number == null) ? 0 : number.hashCode()); + result = prime * result + power; + result = prime * result + ((right == null) ? 0 : right.hashCode()); + result = prime * result + ((type == null) ? 0 : type.hashCode()); + + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + + DualExpr other = (DualExpr) obj; + + if (type != other.type) { + return false; + } + + if (left == null) { + if (other.left != null) { + return false; + } + } else if (!left.equals(other.left)) { + return false; + } + + if (number == null) { + if (other.number != null) + return false; + } else if (!number.equals(other.number)) { + return false; + } + + if (power != other.power) { + return false; + } + + if (right == null) { + if (other.right != null) { + return false; + } + } else if (!right.equals(other.right)) { + return false; + } + + return true; + } +} \ No newline at end of file -- cgit v1.2.3