So, for those who have not been involved in my earlier question on this topic, I have some C# types to represent Sequences, like in math, and I store a formula for generating the nth-term as
Func NTerm;
Now, I would like to compare one Sequence's nth-term formula to another for mathematical equality. I have determined that practically, this requires a limitation to five simple binary operators: +,-,/,*,%.
With tremendous help from Scott Chamberlain on the OP, I've managed to get here. I'm envisioning some kind of recursive solution, but I'm not quite sure where to proceed. Will keep chugging away. Currently, this handles single binary expressions like x + 5 or 7 * c, well, but doesn't deal with other cases. I think the idea is to start from the end of the List, since I think it parses with last operations at the top of the syntax tree and first ones at the bottom. I haven't yet done extensive testing for this method.
UPDATE
public sealed class LambdaParser : ExpressionVisitor
{
public List<BinaryExpression> Expressions { get; } = new List<BinaryExpression> ();
protected override Expression VisitBinary (BinaryExpression node)
{
Expressions.Add (node);
return base.VisitBinary (node);
}
public bool MathEquals (LambdaParser g)
{
try
{
return MathEquals (Expressions, g.Expressions);
} catch (Exception e)
{
throw new Exception ("MathEquals", e);
}
}
public static bool MathEquals (List<BinaryExpression> f, List<BinaryExpression> g)
{
//handle simple cases
if (ReferenceEquals (f, g))
return true;
if (f == null || g == null)
return false;
if (f.Count == 0 || g.Count == 0)
return false;
try
{
//handle single one element cases
if (f.Count == 1 && g.Count == 1)
{
return MathEquals (f[0], g[0]);
}
} catch (Exception e)
{
throw new Exception ("MathEquals", e);
}
throw new NotImplementedException ("Math Equals");
}
static bool MathEquals (BinaryExpression f, BinaryExpression g)
{
if (ReferenceEquals (f, g))
return true;
if (f == null || g == null)
return false;
if (f.NodeType != g.NodeType)
return false;
try
{
switch (f.NodeType)
{
case ExpressionType.Add:
case ExpressionType.Multiply:
return CompareCommutative (f, g);
case ExpressionType.Subtract:
case ExpressionType.Divide:
case ExpressionType.Modulo:
return CompareNonCommutative (f, g);
default:
throw new NotImplementedException ($"Math Equals {nameof(f)}: {f.NodeType}, {nameof(g)}: {g.NodeType}");
}
} catch (Exception e)
{
throw new Exception ($"MathEquals {nameof(f)}: {f.NodeType}, {nameof(g)}: {g.NodeType}", e);
}
}
static bool IsParam (Expression f)
{
return f.NodeType == ExpressionType.Parameter;
}
static bool IsConstant (Expression f)
{
return f.NodeType == ExpressionType.Constant;
}
static bool CompareCommutative (BinaryExpression f, BinaryExpression g)
{
bool left, right;
try
{
//parse left f to left g and right g
left = CompareParamOrConstant (f.Left, g.Left) || CompareParamOrConstant (f.Left, g.Right);
//parse right f to left g and right g
right = CompareParamOrConstant (f.Right, g.Left) || CompareParamOrConstant (f.Right, g.Right);
return left && right;
} catch (Exception e)
{
throw new Exception ($"CompareCommutative {nameof(f)}: {f.NodeType}, {nameof(g)}: {g.NodeType}", e);
}
}
static bool CompareNonCommutative (BinaryExpression f, BinaryExpression g)
{
bool left, right;
try
{
//compare f left to g left
left = CompareParamOrConstant (f.Left, g.Left);
//compare f right to f right
right = CompareParamOrConstant (f.Right, g.Right);
} catch (Exception e)
{
throw new Exception ($"CompareNonCommutative {nameof(f)}: {f.NodeType}, {nameof(g)}: {g.NodeType}", e);
}
return left && right;
}
static bool CompareParamOrConstant (Expression f, Expression g)
{
var ParamF = IsParam (f);
var ConstantF = IsConstant (f);
if (!(ParamF || ConstantF))
{
throw new ArgumentException ($"{nameof(f)} is neither a param or a constant", $"{nameof(f)}");
}
var ParamG = IsParam (g);
var ConstantG = IsConstant (g);
if (!(ParamG || ConstantG))
{
throw new ArgumentException ($"{nameof(g)} is neither a param or a constant", $"{nameof(g)}");
}
if (ParamF)
{
return ParamG;
}
if (ConstantF)
{
return ConstantG && (f as ConstantExpression).Value.Equals ((g as ConstantExpression).Value);
}
}
}
END_UPDATE
I updated the above code to reflect the changes made (mostly refactoring, but also a slightly different approach) after comments reminded me that I'd ignored the non-commutability of some operators.
How do I extend it for expressions with multiple operators of the five defined above? Instead of just 5 * x, something like 2 * x % 3 - x / 5.
Related
Context
Three classes: MetaParticipant, MetaMovie and MetaPerson
A MetaParticipant has one MetaMovie and one MetaPerson
To fix an issue, I created a IsEqual static method in all three.
For the independent ones MetaMovie and MetaPerson, I used (MetaPerson has the same except with its class instead):
public static System.Linq.Expressions.Expression<Func<MetaMovie, bool>> IsEqual(MetaMovie other)
{
if (other.Id > 0) return m => other.Id == m.Id; // Using '> 0' so it skips the new ones in change tracker to the next identifier
return m => other.MetaSource == m.MetaSource && other.ExternalId == m.ExternalId;
}
So, I would like to write the MetaParticipant.IsEqual method, but ain't able to figure out how.
This method will receive a MetaParticipant that can use its MetaMovie and MetaPerson to call the others.
Issue
Here is the MetaParticipant.Equals that IsEqual shall "replace":
public override bool Equals(object obj)
{
if (obj == null) return false;
if (base.Equals(obj)) return true;
if (obj is not MetaParticipant other) return false;
return Movie.Equals(other.Movie) && Person.Equals(other.Person) && JobTitle == other.JobTitle;
}
And where I am up to for IsEqual:
public static Expression<Func<MetaParticipant, bool>> IsEqual(MetaParticipant other)
{
//var own = new Expression<Func<MetaParticipant, bool>() { return x => x.JobTitle == other.JobTitle; };
var mm = MetaMovie.IsEqual(other.Movie);
var mp = MetaPerson.IsEqual(other.Person);
var body = Expression.AndAlso(
Expression.Invoke(mm, Expression.Parameter(other.Movie.GetType(), "mm")),
Expression.Invoke(mp, Expression.Parameter(other.Person.GetType(), "mp"))
);
//body = Expression.AndAlso(body, );
var lambda = Expression.Lambda<Func<MetaParticipant, bool>>(body, Expression.Parameter(typeof(MetaParticipant)));
return lambda;
//return m => Expression.Invoke(mm, Expression.Variable(m.Movie.GetType())) && m.JobTitle == other.JobTitle;
}
Sorry, there is a bit of garbage I kept so you can see some tries I did.
Consider this snippet (much simplified than the original code):
async IAsyncEnumerable<(DateTime, double)> GetSamplesAsync()
{
// ...
var cbpool = new CallbackPool(
HandleBool: (dt, b) => { },
HandleDouble: (dt, d) =>
{
yield return (dt, d); //not allowed
});
while (await cursor.MoveNextAsync(token))
{
this.Parse(cursor.Current, cbpool);
}
}
private record CallbackPool(
Action<DateTime, bool> HandleBool,
Action<DateTime, double> HandleDouble
);
Then, the below Parse is just a behavior-equivalent of the original.
Random _rnd = new Random();
void Parse(object cursor, CallbackPool cbpool)
{
double d = this._rnd.NextDouble(); //d = [0..1)
if (d >= 0.5)
{
cbpool.HandleDouble(new DateTime(), d);
}
else if (d >= 0.25)
{
cbpool.HandleBool(new DateTime(), d >= 0.4);
}
}
However, I do like the GetSamplesAsync code, but the compiler does not: the yield cannot be used within a lambda.
So, I changed the function as follows, although it became much less readable (and also error-prone):
async IAsyncEnumerable<(DateTime, double)> GetSamplesAsync()
{
// ...
(DateTime, double) pair = default;
bool flag = false;
var cbpool = new CallbackPool(
HandleBool: (dt, b) => { },
HandleDouble: (dt, d) =>
{
pair = (dt, d);
flag = true;
});
while (await cursor.MoveNextAsync(token))
{
this.Parse(cursor.Current, cbpool);
if (flag)
{
yield return pair;
}
flag = false;
}
}
I wonder if there is a better way to solve this kind of pattern.
The external flag/pair is pretty dangerous and unnecessary (and it forces a closure); it seems like this bool could be returned from the Parse method, for example:
await foreach (var item in cursor)
{
if (Parse(item, cbpool, out var result))
yield return result;
}
(everything could also be returned via a value-tuple if you don't like the out)
I need to write operation that count nodes who have two sons that equal each other. I tried to it but i got error that not all code path return a value.
please help i have a test
Thanks.
public static int CountWhoHasTwoSameSons(BinNode<int> Head)
{
if (Head != null)
{
if (IsLeaf(Head))
return 1;
if ((Head.HasLeft() && Head.HasRight()) && (Head.GetRight() == Head.GetLeft()))
return 1 + CountWhoHasTwoSameSons(Head.GetLeft()) + CountWhoHasTwoSameSons(Head.GetRight());
}
}
static void Main(string[] args)
{
BinNode<int> t = new BinNode<int>(3);
BinNode<int> t1 = new BinNode<int>(3);
BinNode<int> t2 = new BinNode<int>(3);
BinNode<int> t3 = new BinNode<int>(3);
BinNode<int> t4 = new BinNode<int>(t,3,t1);
BinNode<int> t5 = new BinNode<int>(t2,3,t3);
BinNode<int> t6 = new BinNode<int>(t4,3,null);
BinNode<int> Head = new BinNode<int>(t6,3,t5);
Console.WriteLine(SumTree(Head));
Console.WriteLine(LeafCounter(Head));
Console.WriteLine(CountWhoHasTwoSameSons(Head));
Console.ReadLine();
}
you need to add a return outside an If statement, the compiler can't work out if this function will return something or not. If you can just add a return statement at the end of the function that returns a 0 it should work. Not the most prefered fix, you should really rewrite the function so the return actually is more than just a way of pleasing the compiler but it should work.
Danny
As your error states, your current function may not have return statements in some instances.
public static int CountWhoHasTwoSameSons(BinNode<int> Head)
{
if (Head == null)
return 0;
if (IsLeaf(Head))
return 1;
if ((Head.HasLeft() && Head.HasRight()) &&
(Head.GetRight() == Head.GetLeft())) // It happens with this if statement!
return 1 + CountWhoHasTwoSameSons(Head.GetLeft()) +
CountWhoHasTwoSameSons(Head.GetRight());
}
The error "not all paths return a value" is correct. If your execution flow gets to the third if statement and it is false, then there is no value returned. Your function is defined to always return an int and that case is not being contemplated.
So, try something like this to fix it:
public static int CountWhoHasTwoSameSons(BinNode<int> Head)
{
if (Head != null)
{
if (IsLeaf(Head))
return 1;
if (Head.HasLeft() && Head.HasRight())
{
if (Head.GetRight().GetValue() == Head.GetLeft().GetValue()))
return 1 + CountWhoHasTwoSameSons(Head.GetLeft()) + CountWhoHasTwoSameSons(Head.GetRight());
}
}
return 0;
}
I have a function that needs to check a datetime against a supplied datetime. My function is shown below (it works fine however I don't like it). The only thing that needs to change is the operator but currently I have a few if, else if's etc & lines of code that have been copied.
I'm sure I'm being stupid and there is a much better way of doing this?
enum DateTimeOperator
{
Equals = 0, GreaterThanOrEqualTo, GreaterThan, LessThan
}
My function
bool DateChecked(DateTime dateCheck DateTimeOperator dateOperator)
{
if(dateOperator == DateTimeOperator.Equals)
{
if (File.GetLastWriteTime(filePath + fileName).Date == dateCheck .Date)
return true;
else
return false;
}
else if(dateOperator == DateTimeOperator.GreaterThanOrEqualTo)
{
if (File.GetLastWriteTime(filePath + fileName).Date >= dateCheck .Date)
return true;
else
return false;
}
else if(dateOperator == DateTimeOperator.LessThan)
{
if (File.GetLastWriteTime(filePath + fileName).Date < dateCheck .Date)
return true;
else
return false;
}
}
Just for fun without if/switch:
Dictionary<DateTimeOperator, Func<DateTime, DateTime, bool>> operatorComparer = new Dictionary<DateTimeOperator, Func<DateTime, DateTime, bool>>
{
{ DateTimeOperator.Equals, (a, b) => a == b },
{ DateTimeOperator.GreaterThanOrEqualTo, (a, b) => a >= b },
{ DateTimeOperator.GreaterThan, (a, b) => a > b },
{ DateTimeOperator.LessThan, (a, b) => a < b }
};
bool DateChecked(DateTime dateCheck, DateTimeOperator dateOperator)
{
//TODO: add a sanity check
return operatorComparer[dateOperator](File.GetLastWriteTime(filePath + fileName).Date, dateCheck .Date);
}
I'd suggest an extension method: you have no need to put any switch or if whenever you want to compare dates using DateTimeOperator:
public static class DateTimeOperatorExtensions {
public static Func<Boolean, DateTime, DateTime> Comparison(this DateTimeOperator operation) {
switch(operation) {
//TODO: implenent other cases: i.e. DateTimeOperator.NotEquals here
DateTimeOperator.Equals:
return (left, right) => left == right;
DateTimeOperator.GreaterThanOrEqualTo:
return (left, right) => left >= right;
DateTimeOperator.LessThan:
return (left, right) => left < right;
default:
return (left, right) => left == right;
}
}
}
...
bool DateChecked(DateTime dateCheck DateTimeOperator dateOperator) {
return dateOperator.Comparison()(dateCheck, File.GetLastWriteTime(filePath + fileName).Date);
}
Here is my version of the above code, which I believe to be simpler and more readable
bool DateChecked(DateTime dateCheck DateTimeOperator dateOperator)
{
var result = false;
var myDate = File.GetLastWriteTime(filePath + fileName).Date;
switch(dateOperator)
{
case DateTimeOperator.Equals:
result = myDate == dateCheck.Date;
break;
case DateTimeOperator.GreaterThanOrEqualTo:
result = myDate >= dateCheck.Date;
break;
case DateTimeOperator.LessThan:
result = myDate < dateCheck.Date;
break;
}
return result;
}
or if you don't like the one return statement
bool DateChecked(DateTime dateCheck DateTimeOperator dateOperator)
{
var myDate = File.GetLastWriteTime(filePath + fileName).Date;
switch(dateOperator)
{
case DateTimeOperator.Equals:
return myDate == dateCheck.Date;
case DateTimeOperator.GreaterThanOrEqualTo:
return myDate >= dateCheck.Date;
case DateTimeOperator.LessThan:
return myDate < dateCheck.Date;
}
}
I'm trying to come up with an elegant way to handle some generated polynomials. Here's the situation we'll focus on (exclusively) for this question:
order is a parameter in generating an nth order polynomial, where n:=order + 1.
i is an integer parameter in the range 0..n
The polynomial has zeros at x_j, where j = 1..n and j ≠ i (it should be clear at this point that StackOverflow needs a new feature or it's present and I don't know it)
The polynomial evaluates to 1 at x_i.
Since this particular code example generates x_1 .. x_n, I'll explain how they're found in the code. The points are evenly spaced x_j = j * elementSize / order apart, where n = order + 1.
I generate a Func<double, double> to evaluate this polynomial¹.
private static Func<double, double> GeneratePsi(double elementSize, int order, int i)
{
if (order < 1)
throw new ArgumentOutOfRangeException("order", "order must be greater than 0.");
if (i < 0)
throw new ArgumentOutOfRangeException("i", "i cannot be less than zero.");
if (i > order)
throw new ArgumentException("i", "i cannot be greater than order");
ParameterExpression xp = Expression.Parameter(typeof(double), "x");
// generate the terms of the factored polynomial in form (x_j - x)
List<Expression> factors = new List<Expression>();
for (int j = 0; j <= order; j++)
{
if (j == i)
continue;
double p = j * elementSize / order;
factors.Add(Expression.Subtract(Expression.Constant(p), xp));
}
// evaluate the result at the point x_i to get scaleInv=1.0/scale.
double xi = i * elementSize / order;
double scaleInv = Enumerable.Range(0, order + 1).Aggregate(0.0, (product, j) => product * (j == i ? 1.0 : (j * elementSize / order - xi)));
/* generate an expression to evaluate
* (x_0 - x) * (x_1 - x) .. (x_n - x) / (x_i - x)
* obviously the term (x_i - x) is cancelled in this result, but included here to make the result clear
*/
Expression expr = factors.Skip(1).Aggregate(factors[0], Expression.Multiply);
// multiplying by scale forces the condition f(x_i)=1
expr = Expression.Multiply(Expression.Constant(1.0 / scaleInv), expr);
Expression<Func<double, double>> lambdaMethod = Expression.Lambda<Func<double, double>>(expr, xp);
return lambdaMethod.Compile();
}
The problem: I also need to evaluate ψ′=dψ/dx. To do this, I can rewrite ψ=scale×(x_0 - x)(x_1 - x)×..×(x_n - x)/(x_i - x) in the form ψ=α_n×x^n + α_n×x^(n-1) + .. + α_1×x + α_0. This gives ψ′=n×α_n×x^(n-1) + (n-1)×α_n×x^(n-2) + .. + 1×α_1.
For computational reasons, we can rewrite the final answer without calls to Math.Pow by writing ψ′=x×(x×(x×(..) - β_2) - β_1) - β_0.
To do all this "trickery" (all very basic algebra), I need a clean way to:
Expand a factored Expression containing ConstantExpression and ParameterExpression leaves and basic mathematical operations (end up either BinaryExpression with the NodeType set to the operation) - the result here can include InvocationExpression elements to the MethodInfo for Math.Pow which we'll handle in a special manner throughout.
Then I take the derivative with respect to some specified ParameterExpression. Terms in the result where the right hand side parameter to an invocation of Math.Pow was the constant 2 are replaced by the ConstantExpression(2) multiplied by what was the left hand side (the invocation of Math.Pow(x,1) is removed). Terms in the result that become zero because they were constant with respect to x are removed.
Then factor out the instances of some specific ParameterExpression where they occur as the left hand side parameter to an invocation of Math.Pow. When the right hand side of the invocation becomes a ConstantExpression with the value 1, we replace the invocation with just the ParameterExpression itself.
¹ In the future, I'd like the method to take a ParameterExpression and return an Expression that evaluates based on that parameter. That way I can aggregate generated functions. I'm not there yet.
² In the future, I hope to release a general library for working with LINQ Expressions as symbolic math.
I wrote the basics of several symbolic math features using the ExpressionVisitor type in .NET 4. It's not perfect, but it looks like the foundation of a viable solution.
Symbolic is a public static class exposing methods like Expand, Simplify, and PartialDerivative
ExpandVisitor is an internal helper type that expands expressions
SimplifyVisitor is an internal helper type that simplifies expressions
DerivativeVisitor is an internal helper type that takes the derivative of an expression
ListPrintVisitor is an internal helper type that converts an Expression to a prefix notation with a Lisp syntax
Symbolic
public static class Symbolic
{
public static Expression Expand(Expression expression)
{
return new ExpandVisitor().Visit(expression);
}
public static Expression Simplify(Expression expression)
{
return new SimplifyVisitor().Visit(expression);
}
public static Expression PartialDerivative(Expression expression, ParameterExpression parameter)
{
bool totalDerivative = false;
return new DerivativeVisitor(parameter, totalDerivative).Visit(expression);
}
public static string ToString(Expression expression)
{
ConstantExpression result = (ConstantExpression)new ListPrintVisitor().Visit(expression);
return result.Value.ToString();
}
}
Expanding expressions with ExpandVisitor
internal class ExpandVisitor : ExpressionVisitor
{
protected override Expression VisitBinary(BinaryExpression node)
{
var left = Visit(node.Left);
var right = Visit(node.Right);
if (node.NodeType == ExpressionType.Multiply)
{
Expression[] leftNodes = GetAddedNodes(left).ToArray();
Expression[] rightNodes = GetAddedNodes(right).ToArray();
var result =
leftNodes
.SelectMany(x => rightNodes.Select(y => Expression.Multiply(x, y)))
.Aggregate((sum, term) => Expression.Add(sum, term));
return result;
}
if (node.Left == left && node.Right == right)
return node;
return Expression.MakeBinary(node.NodeType, left, right, node.IsLiftedToNull, node.Method, node.Conversion);
}
/// <summary>
/// Treats the <paramref name="node"/> as the sum (or difference) of one or more child nodes and returns the
/// the individual addends in the sum.
/// </summary>
private static IEnumerable<Expression> GetAddedNodes(Expression node)
{
BinaryExpression binary = node as BinaryExpression;
if (binary != null)
{
switch (binary.NodeType)
{
case ExpressionType.Add:
foreach (var n in GetAddedNodes(binary.Left))
yield return n;
foreach (var n in GetAddedNodes(binary.Right))
yield return n;
yield break;
case ExpressionType.Subtract:
foreach (var n in GetAddedNodes(binary.Left))
yield return n;
foreach (var n in GetAddedNodes(binary.Right))
yield return Expression.Negate(n);
yield break;
default:
break;
}
}
yield return node;
}
}
Taking a derivative with DerivativeVisitor
internal class DerivativeVisitor : ExpressionVisitor
{
private ParameterExpression _parameter;
private bool _totalDerivative;
public DerivativeVisitor(ParameterExpression parameter, bool totalDerivative)
{
if (_totalDerivative)
throw new NotImplementedException();
_parameter = parameter;
_totalDerivative = totalDerivative;
}
protected override Expression VisitBinary(BinaryExpression node)
{
switch (node.NodeType)
{
case ExpressionType.Add:
case ExpressionType.Subtract:
return Expression.MakeBinary(node.NodeType, Visit(node.Left), Visit(node.Right));
case ExpressionType.Multiply:
return Expression.Add(Expression.Multiply(node.Left, Visit(node.Right)), Expression.Multiply(Visit(node.Left), node.Right));
case ExpressionType.Divide:
return Expression.Divide(Expression.Subtract(Expression.Multiply(Visit(node.Left), node.Right), Expression.Multiply(node.Left, Visit(node.Right))), Expression.Power(node.Right, Expression.Constant(2)));
case ExpressionType.Power:
if (node.Right is ConstantExpression)
{
return Expression.Multiply(node.Right, Expression.Multiply(Visit(node.Left), Expression.Subtract(node.Right, Expression.Constant(1))));
}
else if (node.Left is ConstantExpression)
{
return Expression.Multiply(node, MathExpressions.Log(node.Left));
}
else
{
return Expression.Multiply(node, Expression.Add(
Expression.Multiply(Visit(node.Left), Expression.Divide(node.Right, node.Left)),
Expression.Multiply(Visit(node.Right), MathExpressions.Log(node.Left))
));
}
default:
throw new NotImplementedException();
}
}
protected override Expression VisitConstant(ConstantExpression node)
{
return MathExpressions.Zero;
}
protected override Expression VisitInvocation(InvocationExpression node)
{
MemberExpression memberExpression = node.Expression as MemberExpression;
if (memberExpression != null)
{
var member = memberExpression.Member;
if (member.DeclaringType != typeof(Math))
throw new NotImplementedException();
switch (member.Name)
{
case "Log":
return Expression.Divide(Visit(node.Expression), node.Expression);
case "Log10":
return Expression.Divide(Visit(node.Expression), Expression.Multiply(Expression.Constant(Math.Log(10)), node.Expression));
case "Exp":
case "Sin":
case "Cos":
default:
throw new NotImplementedException();
}
}
throw new NotImplementedException();
}
protected override Expression VisitParameter(ParameterExpression node)
{
if (node == _parameter)
return MathExpressions.One;
return MathExpressions.Zero;
}
}
Simplifying expressions with SimplifyVisitor
internal class SimplifyVisitor : ExpressionVisitor
{
protected override Expression VisitBinary(BinaryExpression node)
{
var left = Visit(node.Left);
var right = Visit(node.Right);
ConstantExpression leftConstant = left as ConstantExpression;
ConstantExpression rightConstant = right as ConstantExpression;
if (leftConstant != null && rightConstant != null
&& (leftConstant.Value is double) && (rightConstant.Value is double))
{
double leftValue = (double)leftConstant.Value;
double rightValue = (double)rightConstant.Value;
switch (node.NodeType)
{
case ExpressionType.Add:
return Expression.Constant(leftValue + rightValue);
case ExpressionType.Subtract:
return Expression.Constant(leftValue - rightValue);
case ExpressionType.Multiply:
return Expression.Constant(leftValue * rightValue);
case ExpressionType.Divide:
return Expression.Constant(leftValue / rightValue);
default:
throw new NotImplementedException();
}
}
switch (node.NodeType)
{
case ExpressionType.Add:
if (IsZero(left))
return right;
if (IsZero(right))
return left;
break;
case ExpressionType.Subtract:
if (IsZero(left))
return Expression.Negate(right);
if (IsZero(right))
return left;
break;
case ExpressionType.Multiply:
if (IsZero(left) || IsZero(right))
return MathExpressions.Zero;
if (IsOne(left))
return right;
if (IsOne(right))
return left;
break;
case ExpressionType.Divide:
if (IsZero(right))
throw new DivideByZeroException();
if (IsZero(left))
return MathExpressions.Zero;
if (IsOne(right))
return left;
break;
default:
throw new NotImplementedException();
}
return Expression.MakeBinary(node.NodeType, left, right);
}
protected override Expression VisitUnary(UnaryExpression node)
{
var operand = Visit(node.Operand);
ConstantExpression operandConstant = operand as ConstantExpression;
if (operandConstant != null && (operandConstant.Value is double))
{
double operandValue = (double)operandConstant.Value;
switch (node.NodeType)
{
case ExpressionType.Negate:
if (operandValue == 0.0)
return MathExpressions.Zero;
return Expression.Constant(-operandValue);
default:
throw new NotImplementedException();
}
}
switch (node.NodeType)
{
case ExpressionType.Negate:
if (operand.NodeType == ExpressionType.Negate)
{
return ((UnaryExpression)operand).Operand;
}
break;
default:
throw new NotImplementedException();
}
return Expression.MakeUnary(node.NodeType, operand, node.Type);
}
private static bool IsZero(Expression expression)
{
ConstantExpression constant = expression as ConstantExpression;
if (constant != null)
{
if (constant.Value.Equals(0.0))
return true;
}
return false;
}
private static bool IsOne(Expression expression)
{
ConstantExpression constant = expression as ConstantExpression;
if (constant != null)
{
if (constant.Value.Equals(1.0))
return true;
}
return false;
}
}
Formatting expressions for display with ListPrintVisitor
internal class ListPrintVisitor : ExpressionVisitor
{
protected override Expression VisitBinary(BinaryExpression node)
{
string op = null;
switch (node.NodeType)
{
case ExpressionType.Add:
op = "+";
break;
case ExpressionType.Subtract:
op = "-";
break;
case ExpressionType.Multiply:
op = "*";
break;
case ExpressionType.Divide:
op = "/";
break;
default:
throw new NotImplementedException();
}
var left = Visit(node.Left);
var right = Visit(node.Right);
string result = string.Format("({0} {1} {2})", op, ((ConstantExpression)left).Value, ((ConstantExpression)right).Value);
return Expression.Constant(result);
}
protected override Expression VisitConstant(ConstantExpression node)
{
if (node.Value is string)
return node;
return Expression.Constant(node.Value.ToString());
}
protected override Expression VisitParameter(ParameterExpression node)
{
return Expression.Constant(node.Name);
}
}
Testing the results
[TestMethod]
public void BasicSymbolicTest()
{
ParameterExpression x = Expression.Parameter(typeof(double), "x");
Expression linear = Expression.Add(Expression.Constant(3.0), x);
Assert.AreEqual("(+ 3 x)", Symbolic.ToString(linear));
Expression quadratic = Expression.Multiply(linear, Expression.Add(Expression.Constant(2.0), x));
Assert.AreEqual("(* (+ 3 x) (+ 2 x))", Symbolic.ToString(quadratic));
Expression expanded = Symbolic.Expand(quadratic);
Assert.AreEqual("(+ (+ (+ (* 3 2) (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(expanded));
Assert.AreEqual("(+ (+ (+ 6 (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(Symbolic.Simplify(expanded)));
Expression derivative = Symbolic.PartialDerivative(expanded, x);
Assert.AreEqual("(+ (+ (+ (+ (* 3 0) (* 0 2)) (+ (* 3 1) (* 0 x))) (+ (* x 0) (* 1 2))) (+ (* x 1) (* 1 x)))", Symbolic.ToString(derivative));
Expression simplified = Symbolic.Simplify(derivative);
Assert.AreEqual("(+ 5 (+ x x))", Symbolic.ToString(simplified));
}