//----------------------------------------------------------------------------- // The symbolic algebra system used to write our constraint equations; // routines to build expressions in software or from a user-provided string, // and to compute the partial derivatives that we'll use when write our // Jacobian matrix. // // Copyright 2008-2013 Jonathan Westhues. //----------------------------------------------------------------------------- #include "solvespace.h" ExprVector ExprVector::From(Expr *x, Expr *y, Expr *z) { ExprVector r = { x, y, z}; return r; } ExprVector ExprVector::From(Vector vn) { ExprVector ve; ve.x = Expr::From(vn.x); ve.y = Expr::From(vn.y); ve.z = Expr::From(vn.z); return ve; } ExprVector ExprVector::From(hParam x, hParam y, hParam z) { ExprVector ve; ve.x = Expr::From(x); ve.y = Expr::From(y); ve.z = Expr::From(z); return ve; } ExprVector ExprVector::From(double x, double y, double z) { ExprVector ve; ve.x = Expr::From(x); ve.y = Expr::From(y); ve.z = Expr::From(z); return ve; } ExprVector ExprVector::Minus(ExprVector b) const { ExprVector r; r.x = x->Minus(b.x); r.y = y->Minus(b.y); r.z = z->Minus(b.z); return r; } ExprVector ExprVector::Plus(ExprVector b) const { ExprVector r; r.x = x->Plus(b.x); r.y = y->Plus(b.y); r.z = z->Plus(b.z); return r; } Expr *ExprVector::Dot(ExprVector b) const { Expr *r; r = x->Times(b.x); r = r->Plus(y->Times(b.y)); r = r->Plus(z->Times(b.z)); return r; } ExprVector ExprVector::Cross(ExprVector b) const { ExprVector r; r.x = (y->Times(b.z))->Minus(z->Times(b.y)); r.y = (z->Times(b.x))->Minus(x->Times(b.z)); r.z = (x->Times(b.y))->Minus(y->Times(b.x)); return r; } ExprVector ExprVector::ScaledBy(Expr *s) const { ExprVector r; r.x = x->Times(s); r.y = y->Times(s); r.z = z->Times(s); return r; } ExprVector ExprVector::WithMagnitude(Expr *s) const { Expr *m = Magnitude(); return ScaledBy(s->Div(m)); } Expr *ExprVector::Magnitude() const { Expr *r; r = x->Square(); r = r->Plus(y->Square()); r = r->Plus(z->Square()); return r->Sqrt(); } Vector ExprVector::Eval() const { Vector r; r.x = x->Eval(); r.y = y->Eval(); r.z = z->Eval(); return r; } ExprQuaternion ExprQuaternion::From(hParam w, hParam vx, hParam vy, hParam vz) { ExprQuaternion q; q.w = Expr::From(w); q.vx = Expr::From(vx); q.vy = Expr::From(vy); q.vz = Expr::From(vz); return q; } ExprQuaternion ExprQuaternion::From(Expr *w, Expr *vx, Expr *vy, Expr *vz) { ExprQuaternion q; q.w = w; q.vx = vx; q.vy = vy; q.vz = vz; return q; } ExprQuaternion ExprQuaternion::From(Quaternion qn) { ExprQuaternion qe; qe.w = Expr::From(qn.w); qe.vx = Expr::From(qn.vx); qe.vy = Expr::From(qn.vy); qe.vz = Expr::From(qn.vz); return qe; } ExprVector ExprQuaternion::RotationU() const { ExprVector u; Expr *two = Expr::From(2); u.x = w->Square(); u.x = (u.x)->Plus(vx->Square()); u.x = (u.x)->Minus(vy->Square()); u.x = (u.x)->Minus(vz->Square()); u.y = two->Times(w->Times(vz)); u.y = (u.y)->Plus(two->Times(vx->Times(vy))); u.z = two->Times(vx->Times(vz)); u.z = (u.z)->Minus(two->Times(w->Times(vy))); return u; } ExprVector ExprQuaternion::RotationV() const { ExprVector v; Expr *two = Expr::From(2); v.x = two->Times(vx->Times(vy)); v.x = (v.x)->Minus(two->Times(w->Times(vz))); v.y = w->Square(); v.y = (v.y)->Minus(vx->Square()); v.y = (v.y)->Plus(vy->Square()); v.y = (v.y)->Minus(vz->Square()); v.z = two->Times(w->Times(vx)); v.z = (v.z)->Plus(two->Times(vy->Times(vz))); return v; } ExprVector ExprQuaternion::RotationN() const { ExprVector n; Expr *two = Expr::From(2); n.x = two->Times( w->Times(vy)); n.x = (n.x)->Plus (two->Times(vx->Times(vz))); n.y = two->Times(vy->Times(vz)); n.y = (n.y)->Minus(two->Times( w->Times(vx))); n.z = w->Square(); n.z = (n.z)->Minus(vx->Square()); n.z = (n.z)->Minus(vy->Square()); n.z = (n.z)->Plus (vz->Square()); return n; } ExprVector ExprQuaternion::Rotate(ExprVector p) const { // Express the point in the new basis return (RotationU().ScaledBy(p.x)).Plus( RotationV().ScaledBy(p.y)).Plus( RotationN().ScaledBy(p.z)); } ExprQuaternion ExprQuaternion::Times(ExprQuaternion b) const { Expr *sa = w, *sb = b.w; ExprVector va = { vx, vy, vz }; ExprVector vb = { b.vx, b.vy, b.vz }; ExprQuaternion r; r.w = (sa->Times(sb))->Minus(va.Dot(vb)); ExprVector vr = vb.ScaledBy(sa).Plus( va.ScaledBy(sb).Plus( va.Cross(vb))); r.vx = vr.x; r.vy = vr.y; r.vz = vr.z; return r; } Expr *ExprQuaternion::Magnitude() const { return ((w ->Square())->Plus( (vx->Square())->Plus( (vy->Square())->Plus( (vz->Square())))))->Sqrt(); } Expr *Expr::From(hParam p) { Expr *r = AllocExpr(); r->op = Op::PARAM; r->parh = p; return r; } Expr *Expr::From(double v) { // Statically allocate common constants. // Note: this is only valid because AllocExpr() uses AllocTemporary(), // and Expr* is never explicitly freed. if(v == 0.0) { static Expr zero(0.0); return &zero; } if(v == 1.0) { static Expr one(1.0); return &one; } if(v == -1.0) { static Expr mone(-1.0); return &mone; } if(v == 0.5) { static Expr half(0.5); return ½ } if(v == -0.5) { static Expr mhalf(-0.5); return &mhalf; } Expr *r = AllocExpr(); r->op = Op::CONSTANT; r->v = v; return r; } Expr *Expr::AnyOp(Op newOp, Expr *b) { Expr *r = AllocExpr(); r->op = newOp; r->a = this; r->b = b; return r; } int Expr::Children() const { switch(op) { case Op::PARAM: case Op::PARAM_PTR: case Op::CONSTANT: return 0; case Op::PLUS: case Op::MINUS: case Op::TIMES: case Op::DIV: return 2; case Op::NEGATE: case Op::SQRT: case Op::SQUARE: case Op::SIN: case Op::COS: case Op::ASIN: case Op::ACOS: return 1; case Op::PAREN: case Op::BINARY_OP: case Op::UNARY_OP: case Op::ALL_RESOLVED: break; } ssassert(false, "Unexpected operation"); } int Expr::Nodes() const { switch(Children()) { case 0: return 1; case 1: return 1 + a->Nodes(); case 2: return 1 + a->Nodes() + b->Nodes(); default: ssassert(false, "Unexpected children count"); } } Expr *Expr::DeepCopy() const { Expr *n = AllocExpr(); *n = *this; int c = n->Children(); if(c > 0) n->a = a->DeepCopy(); if(c > 1) n->b = b->DeepCopy(); return n; } Expr *Expr::DeepCopyWithParamsAsPointers(IdList *firstTry, IdList *thenTry) const { Expr *n = AllocExpr(); if(op == Op::PARAM) { // A param that is referenced by its hParam gets rewritten to go // straight in to the parameter table with a pointer, or simply // into a constant if it's already known. Param *p = firstTry->FindByIdNoOops(parh); if(!p) p = thenTry->FindById(parh); if(p->known) { n->op = Op::CONSTANT; n->v = p->val; } else { n->op = Op::PARAM_PTR; n->parp = p; } return n; } *n = *this; int c = n->Children(); if(c > 0) n->a = a->DeepCopyWithParamsAsPointers(firstTry, thenTry); if(c > 1) n->b = b->DeepCopyWithParamsAsPointers(firstTry, thenTry); return n; } double Expr::Eval() const { switch(op) { case Op::PARAM: return SK.GetParam(parh)->val; case Op::PARAM_PTR: return parp->val; case Op::CONSTANT: return v; case Op::PLUS: return a->Eval() + b->Eval(); case Op::MINUS: return a->Eval() - b->Eval(); case Op::TIMES: return a->Eval() * b->Eval(); case Op::DIV: return a->Eval() / b->Eval(); case Op::NEGATE: return -(a->Eval()); case Op::SQRT: return sqrt(a->Eval()); case Op::SQUARE: { double r = a->Eval(); return r*r; } case Op::SIN: return sin(a->Eval()); case Op::COS: return cos(a->Eval()); case Op::ACOS: return acos(a->Eval()); case Op::ASIN: return asin(a->Eval()); case Op::PAREN: case Op::BINARY_OP: case Op::UNARY_OP: case Op::ALL_RESOLVED: break; } ssassert(false, "Unexpected operation"); } Expr *Expr::PartialWrt(hParam p) const { Expr *da, *db; switch(op) { case Op::PARAM_PTR: return From(p.v == parp->h.v ? 1 : 0); case Op::PARAM: return From(p.v == parh.v ? 1 : 0); case Op::CONSTANT: return From(0.0); case Op::PLUS: return (a->PartialWrt(p))->Plus(b->PartialWrt(p)); case Op::MINUS: return (a->PartialWrt(p))->Minus(b->PartialWrt(p)); case Op::TIMES: da = a->PartialWrt(p); db = b->PartialWrt(p); return (a->Times(db))->Plus(b->Times(da)); case Op::DIV: da = a->PartialWrt(p); db = b->PartialWrt(p); return ((da->Times(b))->Minus(a->Times(db)))->Div(b->Square()); case Op::SQRT: return (From(0.5)->Div(a->Sqrt()))->Times(a->PartialWrt(p)); case Op::SQUARE: return (From(2.0)->Times(a))->Times(a->PartialWrt(p)); case Op::NEGATE: return (a->PartialWrt(p))->Negate(); case Op::SIN: return (a->Cos())->Times(a->PartialWrt(p)); case Op::COS: return ((a->Sin())->Times(a->PartialWrt(p)))->Negate(); case Op::ASIN: return (From(1)->Div((From(1)->Minus(a->Square()))->Sqrt())) ->Times(a->PartialWrt(p)); case Op::ACOS: return (From(-1)->Div((From(1)->Minus(a->Square()))->Sqrt())) ->Times(a->PartialWrt(p)); case Op::PAREN: case Op::BINARY_OP: case Op::UNARY_OP: case Op::ALL_RESOLVED: break; } ssassert(false, "Unexpected operation"); } uint64_t Expr::ParamsUsed() const { uint64_t r = 0; if(op == Op::PARAM) r |= ((uint64_t)1 << (parh.v % 61)); if(op == Op::PARAM_PTR) r |= ((uint64_t)1 << (parp->h.v % 61)); int c = Children(); if(c >= 1) r |= a->ParamsUsed(); if(c >= 2) r |= b->ParamsUsed(); return r; } bool Expr::DependsOn(hParam p) const { if(op == Op::PARAM) return (parh.v == p.v); if(op == Op::PARAM_PTR) return (parp->h.v == p.v); int c = Children(); if(c == 1) return a->DependsOn(p); if(c == 2) return a->DependsOn(p) || b->DependsOn(p); return false; } bool Expr::Tol(double a, double b) { return fabs(a - b) < 0.001; } Expr *Expr::FoldConstants() { Expr *n = AllocExpr(); *n = *this; int c = Children(); if(c >= 1) n->a = a->FoldConstants(); if(c >= 2) n->b = b->FoldConstants(); switch(op) { case Op::PARAM_PTR: case Op::PARAM: case Op::CONSTANT: break; case Op::MINUS: case Op::TIMES: case Op::DIV: case Op::PLUS: // If both ops are known, then we can evaluate immediately if(n->a->op == Op::CONSTANT && n->b->op == Op::CONSTANT) { double nv = n->Eval(); n->op = Op::CONSTANT; n->v = nv; break; } // x + 0 = 0 + x = x if(op == Op::PLUS && n->b->op == Op::CONSTANT && Tol(n->b->v, 0)) { *n = *(n->a); break; } if(op == Op::PLUS && n->a->op == Op::CONSTANT && Tol(n->a->v, 0)) { *n = *(n->b); break; } // 1*x = x*1 = x if(op == Op::TIMES && n->b->op == Op::CONSTANT && Tol(n->b->v, 1)) { *n = *(n->a); break; } if(op == Op::TIMES && n->a->op == Op::CONSTANT && Tol(n->a->v, 1)) { *n = *(n->b); break; } // 0*x = x*0 = 0 if(op == Op::TIMES && n->b->op == Op::CONSTANT && Tol(n->b->v, 0)) { n->op = Op::CONSTANT; n->v = 0; break; } if(op == Op::TIMES && n->a->op == Op::CONSTANT && Tol(n->a->v, 0)) { n->op = Op::CONSTANT; n->v = 0; break; } break; case Op::SQRT: case Op::SQUARE: case Op::NEGATE: case Op::SIN: case Op::COS: case Op::ASIN: case Op::ACOS: if(n->a->op == Op::CONSTANT) { double nv = n->Eval(); n->op = Op::CONSTANT; n->v = nv; } break; case Op::PAREN: case Op::BINARY_OP: case Op::UNARY_OP: case Op::ALL_RESOLVED: ssassert(false, "Unexpected operation"); } return n; } void Expr::Substitute(hParam oldh, hParam newh) { ssassert(op != Op::PARAM_PTR, "Expected an expression that refer to params via handles"); if(op == Op::PARAM && parh.v == oldh.v) { parh = newh; } int c = Children(); if(c >= 1) a->Substitute(oldh, newh); if(c >= 2) b->Substitute(oldh, newh); } //----------------------------------------------------------------------------- // If the expression references only one parameter that appears in pl, then // return that parameter. If no param is referenced, then return NO_PARAMS. // If multiple params are referenced, then return MULTIPLE_PARAMS. //----------------------------------------------------------------------------- const hParam Expr::NO_PARAMS = { 0 }; const hParam Expr::MULTIPLE_PARAMS = { 1 }; hParam Expr::ReferencedParams(ParamList *pl) const { if(op == Op::PARAM) { if(pl->FindByIdNoOops(parh)) { return parh; } else { return NO_PARAMS; } } ssassert(op != Op::PARAM_PTR, "Expected an expression that refer to params via handles"); int c = Children(); if(c == 0) { return NO_PARAMS; } else if(c == 1) { return a->ReferencedParams(pl); } else if(c == 2) { hParam pa, pb; pa = a->ReferencedParams(pl); pb = b->ReferencedParams(pl); if(pa.v == NO_PARAMS.v) { return pb; } else if(pb.v == NO_PARAMS.v) { return pa; } else if(pa.v == pb.v) { return pa; // either, doesn't matter } else { return MULTIPLE_PARAMS; } } else ssassert(false, "Unexpected children count"); } //----------------------------------------------------------------------------- // Routines to pretty-print an expression. Mostly for debugging. //----------------------------------------------------------------------------- std::string Expr::Print() const { char c; switch(op) { case Op::PARAM: return ssprintf("param(%08x)", parh.v); case Op::PARAM_PTR: return ssprintf("param(p%08x)", parp->h.v); case Op::CONSTANT: return ssprintf("%.3f", v); case Op::PLUS: c = '+'; goto p; case Op::MINUS: c = '-'; goto p; case Op::TIMES: c = '*'; goto p; case Op::DIV: c = '/'; goto p; p: return "(" + a->Print() + " " + c + " " + b->Print() + ")"; break; case Op::NEGATE: return "(- " + a->Print() + ")"; case Op::SQRT: return "(sqrt " + a->Print() + ")"; case Op::SQUARE: return "(square " + a->Print() + ")"; case Op::SIN: return "(sin " + a->Print() + ")"; case Op::COS: return "(cos " + a->Print() + ")"; case Op::ASIN: return "(asin " + a->Print() + ")"; case Op::ACOS: return "(acos " + a->Print() + ")"; case Op::PAREN: case Op::BINARY_OP: case Op::UNARY_OP: case Op::ALL_RESOLVED: break; } ssassert(false, "Unexpected operation"); } //----------------------------------------------------------------------------- // A parser; convert a string to an expression. Infix notation, with the // usual shift/reduce approach. I had great hopes for user-entered eq // constraints, but those don't seem very useful, so right now this is just // to provide calculator type functionality wherever numbers are entered. //----------------------------------------------------------------------------- #define MAX_UNPARSED 1024 static Expr *Unparsed[MAX_UNPARSED]; static int UnparsedCnt, UnparsedP; static Expr *Operands[MAX_UNPARSED]; static int OperandsP; static Expr *Operators[MAX_UNPARSED]; static int OperatorsP; static jmp_buf exprjmp; static const char *errors[] = { "operator stack full!", "operator stack empty (get top)", "operator stack empty (pop)", "operand stack full", "operand stack empty", "no token to consume", "end of expression unexpected", "expected: )", "expected expression", "too long", "unknown name", "unexpected characters", }; void Expr::PushOperator(Expr *e) { if(OperatorsP >= MAX_UNPARSED) longjmp(exprjmp, 0); Operators[OperatorsP++] = e; } Expr *Expr::TopOperator() { if(OperatorsP <= 0) longjmp(exprjmp, 1); return Operators[OperatorsP-1]; } Expr *Expr::PopOperator() { if(OperatorsP <= 0) longjmp(exprjmp, 2); return Operators[--OperatorsP]; } void Expr::PushOperand(Expr *e) { if(OperandsP >= MAX_UNPARSED) longjmp(exprjmp, 3); Operands[OperandsP++] = e; } Expr *Expr::PopOperand() { if(OperandsP <= 0) longjmp(exprjmp, 4); return Operands[--OperandsP]; } Expr *Expr::Next() { if(UnparsedP >= UnparsedCnt) return NULL; return Unparsed[UnparsedP]; } void Expr::Consume() { if(UnparsedP >= UnparsedCnt) longjmp(exprjmp, 5); UnparsedP++; } int Expr::Precedence(Expr *e) { if(e->op == Op::ALL_RESOLVED) return -1; // never want to reduce this marker ssassert(e->op == Op::BINARY_OP || e->op == Op::UNARY_OP, "Unexpected operation"); switch(e->c) { case 'q': case 's': case 'c': case 'n': return 30; case '*': case '/': return 20; case '+': case '-': return 10; default: ssassert(false, "Unexpected operator"); } } void Expr::Reduce() { Expr *a, *b; Expr *op = PopOperator(); Expr *n; Op o; switch(op->c) { case '+': o = Op::PLUS; goto c; case '-': o = Op::MINUS; goto c; case '*': o = Op::TIMES; goto c; case '/': o = Op::DIV; goto c; c: b = PopOperand(); a = PopOperand(); n = a->AnyOp(o, b); break; case 'n': n = PopOperand()->Negate(); break; case 'q': n = PopOperand()->Sqrt(); break; case 's': n = (PopOperand()->Times(Expr::From(PI/180)))->Sin(); break; case 'c': n = (PopOperand()->Times(Expr::From(PI/180)))->Cos(); break; default: ssassert(false, "Unexpected operator"); } PushOperand(n); } void Expr::ReduceAndPush(Expr *n) { while(Precedence(n) <= Precedence(TopOperator())) { Reduce(); } PushOperator(n); } void Expr::Parse() { Expr *e = AllocExpr(); e->op = Op::ALL_RESOLVED; PushOperator(e); for(;;) { Expr *n = Next(); if(!n) longjmp(exprjmp, 6); if(n->op == Op::CONSTANT) { PushOperand(n); Consume(); } else if(n->op == Op::PAREN && n->c == '(') { Consume(); Parse(); n = Next(); if(n->op != Op::PAREN || n->c != ')') longjmp(exprjmp, 7); Consume(); } else if(n->op == Op::UNARY_OP) { PushOperator(n); Consume(); continue; } else if(n->op == Op::BINARY_OP && n->c == '-') { // The minus sign is special, because it might be binary or // unary, depending on context. n->op = Op::UNARY_OP; n->c = 'n'; PushOperator(n); Consume(); continue; } else { longjmp(exprjmp, 8); } n = Next(); if(n && n->op == Op::BINARY_OP) { ReduceAndPush(n); Consume(); } else { break; } } while(TopOperator()->op != Op::ALL_RESOLVED) { Reduce(); } PopOperator(); // discard the ALL_RESOLVED marker } void Expr::Lex(const char *in) { while(*in) { if(UnparsedCnt >= MAX_UNPARSED) longjmp(exprjmp, 9); char c = *in; if(isdigit(c) || c == '.') { // A number literal char number[70]; int len = 0; while((isdigit(*in) || *in == '.') && len < 30) { number[len++] = *in; in++; } number[len++] = '\0'; Expr *e = AllocExpr(); e->op = Op::CONSTANT; e->v = atof(number); Unparsed[UnparsedCnt++] = e; } else if(isalpha(c) || c == '_') { char name[70]; int len = 0; while(isforname(*in) && len < 30) { name[len++] = *in; in++; } name[len++] = '\0'; Expr *e = AllocExpr(); if(strcmp(name, "sqrt")==0) { e->op = Op::UNARY_OP; e->c = 'q'; } else if(strcmp(name, "cos")==0) { e->op = Op::UNARY_OP; e->c = 'c'; } else if(strcmp(name, "sin")==0) { e->op = Op::UNARY_OP; e->c = 's'; } else if(strcmp(name, "pi")==0) { e->op = Op::CONSTANT; e->v = PI; } else { longjmp(exprjmp, 10); } Unparsed[UnparsedCnt++] = e; } else if(strchr("+-*/()", c)) { Expr *e = AllocExpr(); e->op = (c == '(' || c == ')') ? Op::PAREN : Op::BINARY_OP; e->c = c; Unparsed[UnparsedCnt++] = e; in++; } else if(isspace(c)) { // Ignore whitespace in++; } else { // This is a lex error. longjmp(exprjmp, 11); } } } Expr *Expr::From(const char *in, bool popUpError) { UnparsedCnt = 0; UnparsedP = 0; OperandsP = 0; OperatorsP = 0; Expr *r; int erridx = setjmp(exprjmp); if(!erridx) { Lex(in); Parse(); r = PopOperand(); } else { dbp("exception: parse/lex error: %s", errors[erridx]); if(popUpError) { Error("Not a valid number or expression: '%s'", in); } return NULL; } return r; }