Simplifying Addition and Multiplication in Polynomials

In my own research on Computer Algebra Systems (CAS), I often work with some pretty cool data structures 😎. In this post, I want to show off one that’s responsible for simplifying polynomials. But first, I have a little secret… In high school, I thought polynomials were boring! The hours I spent finding roots dragged on and on. They didn’t make pretty graphs like the trig functions. It’s only at this point in my life that I more fully appreciate them.

Simply put, polynomials make the world go round! Computing a transcendental function? Do a Taylor (or Chebyshev) approximation. Rendering letters on a screen? Interpolate points with Bezier Curves (with fun videos here and here). Integrating a function? Do quadrature with Legendre Polynomials. They even show up in places you, at first, might not expect. Alright, alright, let’s simplify some polynomials!!

A Basic Tree

Our goal is to simplify the polynomial expression:

\[(-3x) + 2z + x + z + 2x + y^2\]

into

\[y^2 + 3z\]

Doing this by hand is natural. Because addition is associative and commutative, we can scan the list of monomials and “group the like terms”, reordering the additions as needed. From there, we see that \((-3x) + x + 2x\) becomes \((-3 + 1 + 2)x\), which is zero, and \(2z + z\) becomes \((2 + 1)z = 3z\).
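Written out step by step, the regrouping uses commutativity and associativity to move like terms next to each other, then distributivity to factor out each variable:

\[
\begin{aligned}
(-3x) + 2z + x + z + 2x + y^2 &= \big((-3x) + x + 2x\big) + \big(2z + z\big) + y^2 \\
&= (-3 + 1 + 2)x + (2 + 1)z + y^2 \\
&= 0x + 3z + y^2 \\
&= y^2 + 3z
\end{aligned}
\]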

This seems straightforward, but on a computer, expressions are often represented as trees. In this case:

[Figure: the polynomial we’re trying to simplify, represented as a binary tree]

Yikes! All of a sudden, that “simple scan” becomes non-obvious. Lots of great research goes into simplifying general expressions like these, using sophisticated data structures and algorithms. Fortunately, we have a narrower focus: only polynomials. Let’s try to follow our human intuition and really optimize for this specific use case.

Our starting point is this recursive data structure:

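from dataclasses import dataclass

class Node:
  pass

# frozen=True makes nodes immutable and hashable,
# so they can be used as dict keys later on
@dataclass(frozen=True)
class Num(Node):
  value: int

@dataclass(frozen=True)
class Var(Node):
  name: str

# A Binary Node!
@dataclass(frozen=True)
class Add(Node):
  lhs: Node
  rhs: Node

@dataclass(frozen=True)
class Mul(Node):
  lhs: Node
  rhs: Node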

The polynomial in question is then constructed like this:

x = Var("x")
y = Var("y")
z = Var("z")

two = Num(2)
neg_three = Num(-3)

p = Add(
  Add(
    Mul(neg_three, x),
    Mul(two, z)),
  Add(
    Add(x, z),
    Add(
      Mul(two, x),
      Mul(y, y))))

Design guided by human intuition

The “trick” of our human algorithm is treating all the monomials on the same “level” of a tree and scanning across that level. We then reordered monomials to “group like terms”. Finally, we simplified those “groups”. At each step, we used our intuition to apply associativity, commutativity, and distributivity (ACD) rules to do this. Our goal is to turn this intuition into an implementation.

Let’s clarify our plan:

  1. flatten tree
  2. scan tree and group “like terms”
  3. simplify coefficients of “like terms”

Next, we translate the notion of “flat” to code. This isn’t obvious, but to get started, let’s keep things simple and use a list:

class Add(Node):
-  lhs: Node
-  rhs: Node
+  terms: list[Node]

The insight is that the additions can be performed in any order. The original tree structure nests the additions, prescribing a specific evaluation order and obscuring the polynomial structure. Instead, we keep a list of all the terms added together, flattening the tree. We can decide the order of evaluation later. To be clear, there is a natural order of processing from left to right, but we could choose something else like a parallel reduction. Any order is valid! Regardless, the resulting data structure with all the Adds flattened looks like:

[Figure: the original tree with the Add nodes flattened. The Add is an array in which each element is a monomial.]
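Step 1 (flatten) can be sketched as a short recursive function. This is a minimal sketch, assuming the frozen-dataclass tree classes from earlier; to keep the nested and flat forms apart, the binary addition node is renamed to a hypothetical `BinAdd`, and `flatten` returns a plain list of terms rather than an `Add`.

```python
from dataclasses import dataclass

class Node:
  pass

@dataclass(frozen=True)
class Num(Node):
  value: int

@dataclass(frozen=True)
class Var(Node):
  name: str

@dataclass(frozen=True)
class Mul(Node):
  lhs: Node
  rhs: Node

# hypothetical name for the original binary Add node
@dataclass(frozen=True)
class BinAdd(Node):
  lhs: Node
  rhs: Node

def flatten(node: Node) -> list[Node]:
  # nested additions contribute their operands, left to right;
  # anything else (Num, Var, Mul) is a single term
  if isinstance(node, BinAdd):
    return flatten(node.lhs) + flatten(node.rhs)
  return [node]

x, y, z = Var("x"), Var("y"), Var("z")
two, neg_three = Num(2), Num(-3)

p = BinAdd(
  BinAdd(Mul(neg_three, x), Mul(two, z)),
  BinAdd(BinAdd(x, z), BinAdd(Mul(two, x), Mul(y, y))))

terms = flatten(p)
# terms == [Mul(neg_three, x), Mul(two, z), x, z, Mul(two, x), Mul(y, y)]
```

Note that this only flattens additions nested directly inside other additions; an addition sitting under a `Mul` stays a single term, which is what we want here.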

Now we’re ready for step 2 of our simplification procedure (“group” like terms). In my mind, this means putting monomials that have the same variable and power (e.g. \(x^2\)) into a bucket.

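from collections import defaultdict

def split_coeff(monomial: Node) -> tuple[Node, int]:
  # a minimal version of the helper, assuming
  # coefficients are Num nodes on either side
  if isinstance(monomial, Mul):
    if isinstance(monomial.lhs, Num):
      return (monomial.rhs, monomial.lhs.value)
    if isinstance(monomial.rhs, Num):
      return (monomial.lhs, monomial.rhs.value)
  # a bare term has coefficient 1
  return (monomial, 1)

def collect_like_terms(add: Add) -> dict[Node, list[int]]:
  # missing keys start out as an empty list
  like_terms: dict[Node, list[int]] = defaultdict(list)

  for monomial in add.terms:
    # This helper does:
    # y * 2 => (y, 2)
    # 3 * x**2 => (x**2, 3)
    (term, coeff) = split_coeff(monomial)

    # put the coefficients of "like terms"
    # into the same list
    like_terms[term].append(coeff)

  return like_terms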

Finally, we can simplify the coefficients (transforming \((-3x) + x + 2x\) into \((-3 + 1 + 2)x\)):

def simplify(add: Add) -> Add:
  new_add = []

  like_terms = collect_like_terms(add)

  # for each group
  for (term, coeffs) in like_terms.items():

    # simplify coeffs
    new_coeff = sum(coeffs)

    # terms whose coefficients cancel (like x above) disappear
    if new_coeff == 0:
      continue

    # wrap the integer back into a Num node
    new_monomial = Mul(Num(new_coeff), term)

    # add it to our new polynomial
    new_add.append(new_monomial)

  return Add(new_add)
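Running all three steps on the working example confirms that we land on \(y^2 + 3z\). This is a self-contained sketch, assuming a minimal `split_coeff` that only recognizes a leading `Num` coefficient (so `y * y` stays a term with coefficient 1) and keeping a bare term when its coefficients sum to 1; both are illustrative choices, not the only reasonable ones.

```python
from collections import defaultdict
from dataclasses import dataclass

class Node:
  pass

@dataclass(frozen=True)
class Num(Node):
  value: int

@dataclass(frozen=True)
class Var(Node):
  name: str

@dataclass(frozen=True)
class Mul(Node):
  lhs: Node
  rhs: Node

# the flattened Add: a list of monomials
@dataclass
class Add(Node):
  terms: list[Node]

def split_coeff(monomial: Node) -> tuple[Node, int]:
  # Mul(Num(c), t) => (t, c); a bare term has coefficient 1
  if isinstance(monomial, Mul) and isinstance(monomial.lhs, Num):
    return (monomial.rhs, monomial.lhs.value)
  return (monomial, 1)

def simplify(add: Add) -> Add:
  # step 2: bucket coefficients by term (dicts keep insertion order)
  like_terms: defaultdict[Node, list[int]] = defaultdict(list)
  for monomial in add.terms:
    term, coeff = split_coeff(monomial)
    like_terms[term].append(coeff)

  # step 3: sum each bucket and drop terms that cancel
  new_terms: list[Node] = []
  for term, coeffs in like_terms.items():
    total = sum(coeffs)
    if total == 0:
      continue
    new_terms.append(term if total == 1 else Mul(Num(total), term))
  return Add(new_terms)

x, y, z = Var("x"), Var("y"), Var("z")
flat = Add([Mul(Num(-3), x), Mul(Num(2), z), x, z, Mul(Num(2), x), Mul(y, y)])
result = simplify(flat)
# result.terms == [Mul(Num(3), z), Mul(y, y)], i.e. 3z + y^2
```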

The algorithm makes at most two passes over the terms (one to group them, one to simplify the groups), so its time complexity is \(O(n)\), roughly \(2n\) steps, with \(n\) being the number of terms in our polynomial. This is pretty cool, and certainly better than trying to simplify the tree directly! But there’s a problem: in practice, polynomials can become large! To avoid using a lot of memory (and wasting computation on redundant expressions), we want to simplify every time we construct an Add node. This means each incremental addition does a full pass over all terms, which is too costly. Fortunately, we can do better, \(O(1)\) better!!

Leaning into the polynomial

Polynomials have a very specific structure we can take advantage of: all the terms follow the pattern \(c_0x^0 + c_1x^1 + c_2x^2 + \dots + c_nx^n\). We want to group “like terms” (things that have the same power of \(x\)). Ideally, we want \(ax + bx\) factored into \((a+b)x\). Let’s unpack the multiplication and sum into two arrays:

class Add(Node):
  terms: list[Node]
+  coeffs: list[int]

Our working example is then encoded as:

terms  = [ x, z, x, z, x, y**2]
coeffs = [-3, 2, 1, 1, 2, 1]

Here, the term and coefficient that share the same index in the arrays are multiplied together. We’ve effectively “inlined” the monomial into the Add as well.

Better yet, we can more easily compute the terms that can be “grouped” together, ideally like:

terms  = [x, z, y**2]
coeffs = [0, 3, 1]

Here, the term with coefficient 0 (the x term) should be dropped.

Our task now is to design an algorithm that takes the expanded term list and generates the condensed term list. How about this:

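from collections import defaultdict

def fast_simplify(terms: list[Node], coeffs: list[int]) -> Add:
  # running total of the coefficient of each distinct term;
  # missing terms start at 0
  new_coeffs: dict[Node, int] = defaultdict(int)

  for (term, coeff) in zip(terms, coeffs):
    new_coeffs[term] += coeff

  # drop terms whose coefficients cancelled to 0
  kept = {t: c for (t, c) in new_coeffs.items() if c != 0}

  return Add(
    list(kept.keys()),
    list(kept.values()))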

Not bad: we only scan the Add once instead of twice, but we still have a linear simplification algorithm that’s run on every single addition 😞. Interestingly, we’ve used a dict to build a set of terms and record how much of each term it has seen. Such a data structure is called a multiset. Seems pretty useful for simplifying polynomials…
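Python ships a multiset in its standard library: `collections.Counter`, a dict subclass whose missing keys read as 0. Here is `fast_simplify` sketched with it, with plain strings standing in for the term nodes (an assumption to keep the example self-contained) and a plain dict returned instead of an `Add`.

```python
from collections import Counter

def fast_simplify(terms: list[str], coeffs: list[int]) -> dict[str, int]:
  # Counter maps each item to its count; missing keys start at 0
  totals: Counter[str] = Counter()
  for term, coeff in zip(terms, coeffs):
    totals[term] += coeff
  # keep only the terms whose coefficients did not cancel
  return {term: c for term, c in totals.items() if c != 0}

terms  = ["x", "z", "x", "z", "x", "y**2"]
coeffs = [-3, 2, 1, 1, 2, 1]
print(fast_simplify(terms, coeffs))  # {'z': 3, 'y**2': 1}
```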

Here’s the core insight: instead of computing the multiset when optimizing, always store it. In code, this looks like:

class Add(Node):
-  terms: list[Node]
-  coeffs: list[int]
+  terms: dict[Node, int]

Each monomial is now encoded as a single key-value pair (term: coefficient). Pretty slick! So slick, in fact, that we don’t even need a separate simplify function. The addition function is enough:

def add_monomial(lhs: Add, rhs: Node):
  # if rhs is in lhs, get returns its coefficient,
  # otherwise it returns 0
  count = 1 + lhs.terms.get(rhs, 0)

  # dicts have no insert method; assign by key
  lhs.terms[rhs] = count

This requires only two hash operations on rhs (one lookup and one store), so each addition is \(O(1)\)! It’s always simplified! Something else I want to highlight is the .get(rhs, 0) call. This creates the illusion that all possible monomials are contained in the set, with the ones we haven’t added having a count/coefficient of zero. Very smooth indeed!
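To see that the Add stays simplified after every call, here is a runnable version of `add_monomial` with a minimal `Add` class. It’s a sketch: this `Add` holds only the terms dict, and each call adds one copy of a term (a coefficient of 1), just like the snippet above.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class Var:
  name: str

@dataclass
class Add:
  # term -> coefficient; the polynomial is the sum of coeff * term
  terms: dict[object, int] = field(default_factory=dict)

def add_monomial(lhs: Add, rhs: object) -> None:
  # one lookup and one store, independent of len(lhs.terms)
  count = 1 + lhs.terms.get(rhs, 0)
  lhs.terms[rhs] = count

x, z = Var("x"), Var("z")
acc = Add()
for term in [x, z, x, x, z]:
  add_monomial(acc, term)
print(acc.terms)  # {Var(name='x'): 3, Var(name='z'): 2}
```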

Representing Addition with Multiset Operations

Stopping to reflect for a moment, we’ve mapped our problem onto an existing data structure. That’s super useful because we can likely import an existing implementation. To really showcase this, let’s think of summing two Add nodes together.

Let’s start by writing out our expected inputs and outputs:

x = Var("x")
y = Var("y")
z = Var("z")

# x + 2*y
lhs = {
  x : 1,
  y : 2,
}

# z + 3*x
rhs = {
  z : 1,
  x : 3,
}

# (lhs + rhs) becomes:
#  4*x + 2*y + z
result = {
  x : 4,
  y : 2,
  z: 1,
}

Our observations are:

  1. if both lhs and rhs have a term, include the term by adding the respective coefficients,
  2. if a term exists in only one side, include it unchanged.

This is a multiset sum (also called an additive union, because the counts are added rather than maxed)! Not only can we represent simplification as building a multiset, we can represent addition with multiset operations as well!
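Python’s `collections.Counter` provides this operation. A minimal sketch, with strings standing in for the `Var` nodes: `Counter.update` adds counts key by key. One caveat worth knowing: Counter’s `+` operator also sums counts but then discards entries that are zero or negative, so it would silently drop a term like -x.

```python
from collections import Counter

# strings stand in for the Var nodes x, y, z
lhs = Counter({"x": 1, "y": 2})   # x + 2*y
rhs = Counter({"z": 1, "x": 3})   # z + 3*x

# update() adds counts key by key: a multiset sum
result = Counter(lhs)
result.update(rhs)
print(dict(result))  # {'x': 4, 'y': 2, 'z': 1}

# + also sums, but drops non-positive counts: x + (-x) loses the x entry
neg = Counter({"x": -1})
print(lhs + neg)     # Counter({'y': 2})

# update() keeps the zero, matching the .get(rhs, 0) view of the dict
summed = Counter(lhs)
summed.update(neg)
print(dict(summed))  # {'x': 0, 'y': 2}
```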

Customizing the multiset for Algebra

Now that we have some operations defined, let’s test this structure out. What happens when we add a sequence of integers?

accumulator = Add()

for data in [1,2,3,4]:
  add_monomial(accumulator, Num(data))

resulting in:

| term | coefficient |
|------|-------------|
| 1    | 1           |
| 2    | 1           |
| 3    | 1           |
| 4    | 1           |

Considering we already know how to add integers, it feels wasteful to use eight values (four terms and four coefficients) to store a single number. Worse yet, what happens when we add a sequence of zeros (0 + 0 + 0)?

| expression | coefficient |
|------------|-------------|
| 0          | 3           |

We’re using two numbers to encode no information. Now, in the abstract case of a multiset, both cases are perfectly fine: we’ve recorded that the set contains three zeros. But let’s remind ourselves of our goal: we’re modeling algebra! The items in the multiset have a specific interpretation. In this case, it’s the sum of every key multiplied by its corresponding value. Because we have this domain-specific knowledge, we can optimize away all that redundant information. The trick is to have a dedicated accumulator outside the dict:

class Add:
+ accumulator: int
  terms: dict[Node, int]

def add_monomial(lhs: Add, rhs: Node):
+ if isinstance(rhs, Num):
+   lhs.accumulator += rhs.value
+   return

  count = 1 + lhs.terms.get(rhs, 0)

  lhs.terms[rhs] = count

Now we always accumulate literal values into this special field, preventing them from being entered into the data structure. I picture this structure as a generalized fused multiply-add.
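Here is the accumulator idea as a self-contained sketch. It assumes `Num` and `Var` are frozen dataclasses like before and that `Add` starts with an accumulator of 0 (a default the snippet above leaves implicit). Adding 1, 2, 3, 4 now leaves a single number and an empty dict, and so does adding three zeros.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class Num:
  value: int

@dataclass(frozen=True)
class Var:
  name: str

@dataclass
class Add:
  # literal numbers are folded here instead of becoming dict keys
  accumulator: int = 0
  terms: dict[object, int] = field(default_factory=dict)

def add_monomial(lhs: Add, rhs: object) -> None:
  if isinstance(rhs, Num):
    lhs.accumulator += rhs.value
    return
  lhs.terms[rhs] = 1 + lhs.terms.get(rhs, 0)

numbers = Add()
for data in [1, 2, 3, 4]:
  add_monomial(numbers, Num(data))
print(numbers.accumulator, numbers.terms)  # 10 {}

zeros = Add()
for _ in range(3):
  add_monomial(zeros, Num(0))
print(zeros.accumulator, zeros.terms)  # 0 {}
```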

So far so good, but the typing of terms: dict[Node, int] feels limited… What if I need to compute 3/2 + (x / 2)? Here’s our next insight: instead of storing integer values, store fractions!

+ from fractions import Fraction

class Add:
- accumulator: int
+ accumulator: Fraction
- terms: dict[Node, int]
+ terms: dict[Node, Fraction]

Our Add node is quite expressive!
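With `Fraction` coefficients, 3/2 + (x / 2) becomes an accumulator of 3/2 plus the entry x: 1/2. This sketch introduces a hypothetical `add_scaled(lhs, term, coeff)` helper that adds `coeff * term` (with `term=None` meaning a plain constant), since `add_monomial` only ever adds a coefficient of 1. It also removes a term whose coefficient cancels to zero, in line with the earlier goal of dropping zero coefficients.

```python
from dataclasses import dataclass, field
from fractions import Fraction

@dataclass(frozen=True)
class Var:
  name: str

@dataclass
class Add:
  accumulator: Fraction = Fraction(0)
  terms: dict[object, Fraction] = field(default_factory=dict)

def add_scaled(lhs: Add, term: object, coeff: Fraction) -> None:
  # hypothetical helper: add coeff * term; term=None means a constant
  if term is None:
    lhs.accumulator += coeff
    return
  new = lhs.terms.get(term, Fraction(0)) + coeff
  if new == 0:
    lhs.terms.pop(term, None)  # cancelled terms leave the multiset
  else:
    lhs.terms[term] = new

x = Var("x")
expr = Add()
add_scaled(expr, None, Fraction(3, 2))  # 3/2
add_scaled(expr, x, Fraction(1, 2))     # + x/2
add_scaled(expr, x, Fraction(1, 3))     # + x/3
print(expr.accumulator, expr.terms)  # 3/2 {Var(name='x'): Fraction(5, 6)}
```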

Multiplication

The multiset design relies on the AC rules of addition for real numbers. Multiplication has these properties as well! Amazingly, we can immediately repurpose our data structure!

- class Add:
+ class Mul:
  constant: Fraction
  terms: dict[Node, Fraction]

- def add_monomial(lhs: Add, rhs: Node):
+ def mul_expr(lhs: Mul, rhs: Node):

But what’s the interpretation of this data structure? For addition, the keys were unique terms and the values were the coefficients. Now, we have repeated multiplication, which sounds like exponentiation: (x * x * x to x ** 3). So the coefficients become powers! Carrying on this thread, negative multiplicity should encode division (1 / y to y ** -1) and fractional values would mean roots (sqrt(x) to x ** (1/2))! A more complicated example would be:

| expression | multiplicity |
|------------|--------------|
| y          | -1/2         |
| x+y        | 1            |

which represents

\[\frac{x + y}{\sqrt{y}}\]
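The same dict update now adds exponents instead of coefficients. A minimal sketch, with a hypothetical `mul_power(lhs, base, exponent)` helper that multiplies the node by `base ** exponent`, and a plain string standing in for the x+y Add node (assumed hashable here). It drops bases whose exponent reaches zero, which quietly assumes the base is nonzero.

```python
from dataclasses import dataclass, field
from fractions import Fraction

@dataclass(frozen=True)
class Var:
  name: str

@dataclass
class Mul:
  # constant * product of base ** exponent over the entries of terms
  constant: Fraction = Fraction(1)
  terms: dict[object, Fraction] = field(default_factory=dict)

def mul_power(lhs: Mul, base: object, exponent: Fraction) -> None:
  # hypothetical helper: multiply lhs by base ** exponent;
  # equal bases add their exponents, like equal terms add coefficients
  new = lhs.terms.get(base, Fraction(0)) + exponent
  if new == 0:
    lhs.terms.pop(base, None)  # base ** 0 == 1, assuming base != 0
  else:
    lhs.terms[base] = new

x, y = Var("x"), Var("y")

cube = Mul()
for _ in range(3):
  mul_power(cube, x, Fraction(1))
print(cube.terms)  # {Var(name='x'): Fraction(3, 1)}

# (x + y) / sqrt(y), with "x+y" standing in for the Add node
q = Mul()
mul_power(q, "x+y", Fraction(1))
mul_power(q, y, Fraction(-1, 2))
print(q.terms)  # {'x+y': Fraction(1, 1), Var(name='y'): Fraction(-1, 2)}
```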

What’s still missing?

With just this design, we can model addition, subtraction, multiplication, division, exponentiation, and roots. Everything we need to compactly represent polynomials!! Unfortunately, there are a few gotchas.

Function Inverses

CASs are huge, sophisticated systems, but they are also very subtle. For instance, sqrt(x ** 2) maps to (x ** 2) ** (1/2), which feels like it should simplify to x… but for real x this is wrong: it’s abs(x). How should a system detect such an edge case?
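A quick numeric check shows the problem. One conservative guard is to fold nested powers only when the outer exponent is an integer, since \((b^a)^n = b^{an}\) holds for integer \(n\) wherever \(b^a\) is defined. This is a sketch of one possible rule (`fold_power` is a hypothetical helper), not a complete answer to the question above.

```python
import math
from fractions import Fraction
from typing import Optional

# For negative x, folding sqrt(x**2) into x**1 gives the wrong sign:
x = -3.0
print(math.sqrt(x ** 2), abs(x), x)  # 3.0 3.0 -3.0

def fold_power(inner: Fraction, outer: Fraction) -> Optional[Fraction]:
  """Hypothetical rule: the exponent e with (b**inner)**outer == b**e,
  or None when folding could change the value for some real b."""
  # (b**a)**n == b**(a*n) for integer n wherever b**a is defined.
  # A fractional outer exponent is a root, and roots can flip signs:
  # (x**2)**(1/2) is abs(x), not x.
  if outer.denominator == 1:
    return inner * outer
  return None  # leave the nested power alone

print(fold_power(Fraction(1, 2), Fraction(2)))  # 1
print(fold_power(Fraction(2), Fraction(1, 2)))  # None
```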

Normalization

Strange redundancies creep into the system when composing addition and multiplication. For example, the special constant in a Mul node corresponds to an Add node dict value. Working this out more explicitly:

from fractions import Fraction

# x * (y**2)
mul_terms = {Var("x"): 1, Var("y"): 2}
prod0 = Mul(1, mul_terms.copy())

# (2/3) * mul_terms, as an exact Fraction rather than a float
prod1 = Mul(Fraction(2, 3), mul_terms.copy())

# 1 + prod0 + prod1
acc = Add(1)
add_monomial(acc, prod0)
add_monomial(acc, prod1)

We want to get \(1 + \frac{5}{3} xy^2\).

But instead, we get the expression \(1 + \frac{2}{3} xy^2 + xy^2\).

The constant term in prod1 (2/3) obscures the components shared by prod1 and prod0. Unfortunately, we need carefully crafted rules to prevent such cases from ever occurring. In this case, we can “unpack” the Mul.constant term and move it into the coefficient stored in Add.terms. Optimizers are filled with these sorts of normalization rules.
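One way to implement that normalization rule is to strip the constant off any Mul before using it as a key. This sketch assumes a hashable Mul variant that stores its powers as a frozenset of (base, exponent) pairs, plus a hypothetical `add_term` helper. With it, prod0 and prod1 from above collapse into a single term with coefficient 5/3.

```python
from dataclasses import dataclass, field
from fractions import Fraction

@dataclass(frozen=True)
class Var:
  name: str

@dataclass(frozen=True)
class Mul:
  constant: Fraction
  # frozenset of (base, exponent) pairs, so the node is hashable
  powers: frozenset

@dataclass
class Add:
  accumulator: Fraction = Fraction(0)
  terms: dict[object, Fraction] = field(default_factory=dict)

def add_term(lhs: Add, rhs: object) -> None:
  coeff = Fraction(1)
  if isinstance(rhs, Mul):
    # normalization rule: move the Mul's constant into the Add coefficient,
    # so (2/3)*x*y**2 and x*y**2 share the key Mul(1, {x: 1, y: 2})
    coeff = rhs.constant
    rhs = Mul(Fraction(1), rhs.powers)
  lhs.terms[rhs] = lhs.terms.get(rhs, Fraction(0)) + coeff

x, y = Var("x"), Var("y")
powers = frozenset({(x, Fraction(1)), (y, Fraction(2))})

acc = Add(accumulator=Fraction(1))
add_term(acc, Mul(Fraction(1), powers))     # x*y**2
add_term(acc, Mul(Fraction(2, 3), powers))  # (2/3)*x*y**2
print(acc.accumulator, list(acc.terms.values()))  # 1 [Fraction(5, 3)]
```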

Conclusion

Reflecting on our journey, we started with a problem and iteratively developed a sophisticated strategy for optimizing it! Our human intuition served us well, but ultimately, we encoded our problem into an existing data structure (i.e. realized our problem can be mapped to a multiset). Lastly, we tweaked that data structure to suit our application’s needs (e.g. added a dedicated accumulator and extended the dict values to Fractions). Encoding domain-specific knowledge is a powerful tool for optimization indeed.

I hope you enjoyed the ride! Make sure to check out how data structures like this are used in real systems, such as SymEngine, for addition and multiplication.