diff --git a/sorted_set/set.mbt b/sorted_set/set.mbt index 9690cd934..42f2e079e 100644 --- a/sorted_set/set.mbt +++ b/sorted_set/set.mbt @@ -143,6 +143,13 @@ pub fn[V : Compare] SortedSet::union( self : SortedSet[V], src : SortedSet[V], ) -> SortedSet[V] { + fn count(node : Node[V]?) -> Int { + match node { + None => 0 + Some(n) => 1 + count(n.left) + count(n.right) + } + } + fn aux(a : Node[V]?, b : Node[V]?) -> Node[V]? { match (a, b) { (Some(_), None) => a @@ -160,12 +167,7 @@ pub fn[V : Compare] SortedSet::union( let t1 = copy_tree(self.root) let t2 = copy_tree(src.root) let t = aux(t1, t2) - let mut ct = 0 - let ret = { root: t, size: 0 } - // TODO: optimize this. Avoid counting the size of the set. - ret.each(_x => ct = ct + 1) - ret.size = ct - ret + { root: t, size: count(t) } } (Some(_), None) => { root: copy_tree(self.root), size: self.size } (None, Some(_)) => { root: copy_tree(src.root), size: src.size } @@ -293,10 +295,10 @@ pub fn[V : Compare] SortedSet::symmetric_difference( self : SortedSet[V], other : SortedSet[V], ) -> SortedSet[V] { - // TODO: Optimize this function to avoid creating two intermediate sets. - let set1 = self.difference(other) - let set2 = other.difference(self) - set1.union(set2) + let ret = new() + self.each(x => if !other.contains(x) { ret.add(x) }) + other.each(x => if !self.contains(x) { ret.add(x) }) + ret } ///|