Upgrade origin-src to google transit feed 1.2.6
[bus.git] / origin-src / transitfeed-1.2.6 / test / testmerge.py
blob:a/origin-src/transitfeed-1.2.6/test/testmerge.py -> blob:b/origin-src/transitfeed-1.2.6/test/testmerge.py
--- a/origin-src/transitfeed-1.2.6/test/testmerge.py
+++ b/origin-src/transitfeed-1.2.6/test/testmerge.py
@@ -1,1 +1,1535 @@
-
+#!/usr/bin/python2.4
+#
+# Copyright 2007 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the merge module."""
+
+
+__author__ = 'timothy.stranex@gmail.com (Timothy Stranex)'
+
+
+import merge
+import os.path
+import re
+import StringIO
+import transitfeed
+import unittest
+import util
+import zipfile
+
+
+def CheckAttribs(a, b, attrs, assertEquals):
+  """Checks that the objects a and b have the same values for the attributes
+  given in attrs. These checks are done using the given assert function.
+
+  Args:
+    a: The first object.
+    b: The second object.
+    attrs: The list of attribute names (strings).
+    assertEquals: The assertEquals method from unittest.TestCase.
+  """
+  # For Stop objects (and maybe others in the future) Validate converts some
+  # attributes from string to native type
+  a.Validate()
+  b.Validate()
+  for k in attrs:
+    assertEquals(getattr(a, k), getattr(b, k))
+
+
+def CreateAgency():
+  """Create an transitfeed.Agency object for testing.
+
+  Returns:
+    The agency object.
+  """
+  return transitfeed.Agency(name='agency',
+                            url='http://agency',
+                            timezone='Africa/Johannesburg',
+                            id='agency')
+
+
+class TestingProblemReporter(merge.MergeProblemReporter):
+  def __init__(self, accumulator):
+    merge.MergeProblemReporter.__init__(self, accumulator)
+
+
+class TestingProblemAccumulator(transitfeed.ProblemAccumulatorInterface):
+  """This problem reporter keeps track of all problems.
+
+  Attributes:
+    problems: The list of problems reported.
+  """
+
+  def __init__(self):
+    self.problems = []
+    self._expect_classes = []
+
+  def _Report(self, problem):
+    problem.FormatProblem()  # Shouldn't crash
+    self.problems.append(problem)
+    for problem_class in self._expect_classes:
+      if isinstance(problem, problem_class):
+        return
+    raise problem
+
+  def CheckReported(self, problem_class):
+    """Checks if a problem of the given class was reported.
+
+    Args:
+      problem_class: The problem class, a class inheriting from
+                     MergeProblemWithContext.
+
+    Returns:
+      True if a matching problem was reported.
+    """
+    for problem in self.problems:
+      if isinstance(problem, problem_class):
+        return True
+    return False
+
+  def ExpectProblemClass(self, problem_class):
+    """Supresses exception raising for problems inheriting from this class.
+
+    Args:
+      problem_class: The problem class, a class inheriting from
+                     MergeProblemWithContext.
+    """
+    self._expect_classes.append(problem_class)
+
+  def assertExpectedProblemsReported(self, testcase):
+    """Asserts that every expected problem class has been reported.
+
+    The assertions are done using the assert_ method of the testcase.
+
+    Args:
+      testcase: The unittest.TestCase instance.
+    """
+    for problem_class in self._expect_classes:
+      testcase.assert_(self.CheckReported(problem_class))
+
+
+class TestApproximateDistanceBetweenPoints(util.TestCase):
+
+  def _assertWithinEpsilon(self, a, b, epsilon=1.0):
+    """Asserts that a and b are equal to within an epsilon.
+
+    Args:
+      a: The first value (float).
+      b: The second value (float).
+      epsilon: The epsilon value (float).
+    """
+    self.assert_(abs(a-b) < epsilon)
+
+  def testDegenerate(self):
+    p = (30.0, 30.0)
+    self._assertWithinEpsilon(
+        merge.ApproximateDistanceBetweenPoints(p, p), 0.0)
+
+  def testFar(self):
+    p1 = (30.0, 30.0)
+    p2 = (40.0, 40.0)
+    self.assert_(merge.ApproximateDistanceBetweenPoints(p1, p2) > 1e4)
+
+
+class TestSchemedMerge(util.TestCase):
+
+  class TestEntity:
+    """A mock entity (like Route or Stop) for testing."""
+
+    def __init__(self, x, y, z):
+      self.x = x
+      self.y = y
+      self.z = z
+
+  def setUp(self):
+    a_schedule = transitfeed.Schedule()
+    b_schedule = transitfeed.Schedule()
+    merged_schedule = transitfeed.Schedule()
+    accumulator = TestingProblemAccumulator()
+    self.fm = merge.FeedMerger(a_schedule, b_schedule,
+                               merged_schedule,
+                               TestingProblemReporter(accumulator))
+    self.ds = merge.DataSetMerger(self.fm)
+
+    def Migrate(ent, sched, newid):
+      """A migration function for the mock entity."""
+      return self.TestEntity(ent.x, ent.y, ent.z)
+    self.ds._Migrate = Migrate
+
+  def testMergeIdentical(self):
+    class TestAttrib:
+      """An object that is equal to everything."""
+
+      def __cmp__(self, b):
+        return 0
+
+    x = 99
+    a = TestAttrib()
+    b = TestAttrib()
+
+    self.assert_(self.ds._MergeIdentical(x, x) == x)
+    self.assert_(self.ds._MergeIdentical(a, b) is b)
+    self.assertRaises(merge.MergeError, self.ds._MergeIdentical, 1, 2)
+
+  def testMergeIdenticalCaseInsensitive(self):
+    self.assert_(self.ds._MergeIdenticalCaseInsensitive('abc', 'ABC') == 'ABC')
+    self.assert_(self.ds._MergeIdenticalCaseInsensitive('abc', 'AbC') == 'AbC')
+    self.assertRaises(merge.MergeError,
+                      self.ds._MergeIdenticalCaseInsensitive, 'abc', 'bcd')
+    self.assertRaises(merge.MergeError,
+                      self.ds._MergeIdenticalCaseInsensitive, 'abc', 'ABCD')
+
+  def testMergeOptional(self):
+    x = 99
+    y = 100
+
+    self.assertEquals(self.ds._MergeOptional(None, None), None)
+    self.assertEquals(self.ds._MergeOptional(None, x), x)
+    self.assertEquals(self.ds._MergeOptional(x, None), x)
+    self.assertEquals(self.ds._MergeOptional(x, x), x)
+    self.assertRaises(merge.MergeError, self.ds._MergeOptional, x, y)
+
+  def testMergeSameAgency(self):
+    kwargs = {'name': 'xxx',
+              'agency_url': 'http://www.example.com',
+              'agency_timezone': 'Europe/Zurich'}
+    id1 = 'agency1'
+    id2 = 'agency2'
+    id3 = 'agency3'
+    id4 = 'agency4'
+    id5 = 'agency5'
+
+    a = self.fm.a_schedule.NewDefaultAgency(id=id1, **kwargs)
+    b = self.fm.b_schedule.NewDefaultAgency(id=id2, **kwargs)
+    c = transitfeed.Agency(id=id3, **kwargs)
+    self.fm.merged_schedule.AddAgencyObject(c)
+    self.fm.Register(a, b, c)
+
+    d = transitfeed.Agency(id=id4, **kwargs)
+    e = transitfeed.Agency(id=id5, **kwargs)
+    self.fm.a_schedule.AddAgencyObject(d)
+    self.fm.merged_schedule.AddAgencyObject(e)
+    self.fm.Register(d, None, e)
+
+    self.assertEquals(self.ds._MergeSameAgency(id1, id2), id3)
+    self.assertEquals(self.ds._MergeSameAgency(None, None), id3)
+    self.assertEquals(self.ds._MergeSameAgency(id1, None), id3)
+    self.assertEquals(self.ds._MergeSameAgency(None, id2), id3)
+
+    # id1 is not a valid agency_id in the new schedule so it cannot be merged
+    self.assertRaises(KeyError, self.ds._MergeSameAgency, id1, id1)
+
+    # this fails because d (id4) and b (id2) don't map to the same agency
+    # in the merged schedule
+    self.assertRaises(merge.MergeError, self.ds._MergeSameAgency, id4, id2)
+
+  def testSchemedMerge_Success(self):
+
+    def Merger(a, b):
+      return a + b
+
+    scheme = {'x': Merger, 'y': Merger, 'z': Merger}
+    a = self.TestEntity(1, 2, 3)
+    b = self.TestEntity(4, 5, 6)
+    c = self.ds._SchemedMerge(scheme, a, b)
+
+    self.assertEquals(c.x, 5)
+    self.assertEquals(c.y, 7)
+    self.assertEquals(c.z, 9)
+
+  def testSchemedMerge_Failure(self):
+
+    def Merger(a, b):
+      raise merge.MergeError()
+
+    scheme = {'x': Merger, 'y': Merger, 'z': Merger}
+    a = self.TestEntity(1, 2, 3)
+    b = self.TestEntity(4, 5, 6)
+
+    self.assertRaises(merge.MergeError, self.ds._SchemedMerge,
+                      scheme, a, b)
+
+  def testSchemedMerge_NoNewId(self):
+    class TestDataSetMerger(merge.DataSetMerger):
+      def _Migrate(self, entity, schedule, newid):
+        self.newid = newid
+        return entity
+    dataset_merger = TestDataSetMerger(self.fm)
+    a = self.TestEntity(1, 2, 3)
+    b = self.TestEntity(4, 5, 6)
+    dataset_merger._SchemedMerge({}, a, b)
+    self.assertEquals(dataset_merger.newid, False)
+
+  def testSchemedMerge_ErrorTextContainsAttributeNameAndReason(self):
+    reason = 'my reason'
+    attribute_name = 'long_attribute_name'
+
+    def GoodMerger(a, b):
+      return a + b
+
+    def BadMerger(a, b):
+      raise merge.MergeError(reason)
+
+    a = self.TestEntity(1, 2, 3)
+    setattr(a, attribute_name, 1)
+    b = self.TestEntity(4, 5, 6)
+    setattr(b, attribute_name, 2)
+    scheme = {'x': GoodMerger, 'y': GoodMerger, 'z': GoodMerger,
+              attribute_name: BadMerger}
+
+    try:
+      self.ds._SchemedMerge(scheme, a, b)
+    except merge.MergeError, merge_error:
+      error_text = str(merge_error)
+      self.assert_(reason in error_text)
+      self.assert_(attribute_name in error_text)
+
+
+class TestFeedMerger(util.TestCase):
+
+  class Merger:
+    def __init__(self, test, n, should_fail=False):
+      self.test = test
+      self.n = n
+      self.should_fail = should_fail
+
+    def MergeDataSets(self):
+      self.test.called.append(self.n)
+      return not self.should_fail
+
+  def setUp(self):
+    a_schedule = transitfeed.Schedule()
+    b_schedule = transitfeed.Schedule()
+    merged_schedule = transitfeed.Schedule()
+    accumulator = TestingProblemAccumulator()
+    self.fm = merge.FeedMerger(a_schedule, b_schedule,
+                               merged_schedule,
+                               TestingProblemReporter(accumulator))
+    self.called = []
+
+  def testSequence(self):
+    for i in range(10):
+      self.fm.AddMerger(TestFeedMerger.Merger(self, i))
+    self.assert_(self.fm.MergeSchedules())
+    self.assertEquals(self.called, range(10))
+
+  def testStopsAfterError(self):
+    for i in range(10):
+      self.fm.AddMerger(TestFeedMerger.Merger(self, i, i == 5))
+    self.assert_(not self.fm.MergeSchedules())
+    self.assertEquals(self.called, range(6))
+
+  def testRegister(self):
+    s1 = transitfeed.Stop(stop_id='1')
+    s2 = transitfeed.Stop(stop_id='2')
+    s3 = transitfeed.Stop(stop_id='3')
+    self.fm.Register(s1, s2, s3)
+    self.assertEquals(self.fm.a_merge_map, {s1: s3})
+    self.assertEquals('3', s1._migrated_entity.stop_id)
+    self.assertEquals(self.fm.b_merge_map, {s2: s3})
+    self.assertEquals('3', s2._migrated_entity.stop_id)
+
+  def testRegisterNone(self):
+    s2 = transitfeed.Stop(stop_id='2')
+    s3 = transitfeed.Stop(stop_id='3')
+    self.fm.Register(None, s2, s3)
+    self.assertEquals(self.fm.a_merge_map, {})
+    self.assertEquals(self.fm.b_merge_map, {s2: s3})
+    self.assertEquals('3', s2._migrated_entity.stop_id)
+
+  def testGenerateId_Prefix(self):
+    x = 'test'
+    a = self.fm.GenerateId(x)
+    b = self.fm.GenerateId(x)
+    self.assertNotEqual(a, b)
+    self.assert_(a.startswith(x))
+    self.assert_(b.startswith(x))
+
+  def testGenerateId_None(self):
+    a = self.fm.GenerateId(None)
+    b = self.fm.GenerateId(None)
+    self.assertNotEqual(a, b)
+
+  def testGenerateId_InitialCounter(self):
+    a_schedule = transitfeed.Schedule()
+    b_schedule = transitfeed.Schedule()
+    merged_schedule = transitfeed.Schedule()
+
+    for i in range(10):
+      agency = transitfeed.Agency(name='agency', url='http://agency',
+                                  timezone='Africa/Johannesburg',
+                                  id='agency_%d' % i)
+      if i % 2:
+        b_schedule.AddAgencyObject(agency)
+      else:
+        a_schedule.AddAgencyObject(agency)
+    accumulator = TestingProblemAccumulator()
+    feed_merger = merge.FeedMerger(a_schedule, b_schedule,
+                                   merged_schedule,
+                                   TestingProblemReporter(accumulator))
+
+    # check that the postfix number of any generated ids are greater than
+    # the postfix numbers of any ids in the old and new schedules
+    gen_id = feed_merger.GenerateId(None)
+    postfix_num = int(gen_id[gen_id.rfind('_')+1:])
+    self.assert_(postfix_num >= 10)
+
+  def testGetMerger(self):
+    class MergerA(merge.DataSetMerger):
+      pass
+
+    class MergerB(merge.DataSetMerger):
+      pass
+
+    a = MergerA(self.fm)
+    b = MergerB(self.fm)
+
+    self.fm.AddMerger(a)
+    self.fm.AddMerger(b)
+
+    self.assertEquals(self.fm.GetMerger(MergerA), a)
+    self.assertEquals(self.fm.GetMerger(MergerB), b)
+
+  def testGetMerger_Error(self):
+    self.assertRaises(LookupError, self.fm.GetMerger, TestFeedMerger.Merger)
+
+
+class TestServicePeriodMerger(util.TestCase):
+
+  def setUp(self):
+    a_schedule = transitfeed.Schedule()
+    b_schedule = transitfeed.Schedule()
+    merged_schedule = transitfeed.Schedule()
+    self.accumulator = TestingProblemAccumulator()
+    self.problem_reporter = TestingProblemReporter(self.accumulator)
+    self.fm = merge.FeedMerger(a_schedule, b_schedule, merged_schedule,
+                               self.problem_reporter)
+    self.spm = merge.ServicePeriodMerger(self.fm)
+    self.fm.AddMerger(self.spm)
+
+  def _AddTwoPeriods(self, start1, end1, start2, end2):
+    sp1fields = ['test1', start1, end1] + ['1']*7
+    self.sp1 = transitfeed.ServicePeriod(field_list=sp1fields)
+    sp2fields = ['test2', start2, end2] + ['1']*7
+    self.sp2 = transitfeed.ServicePeriod(field_list=sp2fields)
+
+    self.fm.a_schedule.AddServicePeriodObject(self.sp1)
+    self.fm.b_schedule.AddServicePeriodObject(self.sp2)
+
+  def testCheckDisjoint_True(self):
+    self._AddTwoPeriods('20071213', '20071231',
+                        '20080101', '20080201')
+    self.assert_(self.spm.CheckDisjointCalendars())
+
+  def testCheckDisjoint_False1(self):
+    self._AddTwoPeriods('20071213', '20080201',
+                        '20080101', '20080301')
+    self.assert_(not self.spm.CheckDisjointCalendars())
+
+  def testCheckDisjoint_False2(self):
+    self._AddTwoPeriods('20080101', '20090101',
+                        '20070101', '20080601')
+    self.assert_(not self.spm.CheckDisjointCalendars())
+
+  def testCheckDisjoint_False3(self):
+    self._AddTwoPeriods('20080301', '20080901',
+                        '20080101', '20090101')
+    self.assert_(not self.spm.CheckDisjointCalendars())
+
+  def testDisjoinCalendars(self):
+    self._AddTwoPeriods('20071213', '20080201',
+                        '20080101', '20080301')
+    self.spm.DisjoinCalendars('20080101')
+    self.assertEquals(self.sp1.start_date, '20071213')
+    self.assertEquals(self.sp1.end_date, '20071231')
+    self.assertEquals(self.sp2.start_date, '20080101')
+    self.assertEquals(self.sp2.end_date, '20080301')
+
+  def testDisjoinCalendars_Dates(self):
+    self._AddTwoPeriods('20071213', '20080201',
+                        '20080101', '20080301')
+    self.sp1.SetDateHasService('20071201')
+    self.sp1.SetDateHasService('20081231')
+    self.sp2.SetDateHasService('20071201')
+    self.sp2.SetDateHasService('20081231')
+
+    self.spm.DisjoinCalendars('20080101')
+
+    self.assert_('20071201' in self.sp1.date_exceptions.keys())
+    self.assert_('20081231' not in self.sp1.date_exceptions.keys())
+    self.assert_('20071201' not in self.sp2.date_exceptions.keys())
+    self.assert_('20081231' in self.sp2.date_exceptions.keys())
+
+  def testUnion(self):
+    self._AddTwoPeriods('20071213', '20071231',
+                        '20080101', '20080201')
+    self.accumulator.ExpectProblemClass(merge.MergeNotImplemented)
+    self.fm.MergeSchedules()
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetServicePeriodList()), 2)
+
+    # make fields a copy of the service period attributes except service_id
+    fields = list(transitfeed.ServicePeriod._DAYS_OF_WEEK)
+    fields += ['start_date', 'end_date']
+
+    # now check that these attributes are preserved in the merge
+    CheckAttribs(self.sp1, self.fm.a_merge_map[self.sp1], fields,
+                 self.assertEquals)
+    CheckAttribs(self.sp2, self.fm.b_merge_map[self.sp2], fields,
+                 self.assertEquals)
+
+    self.accumulator.assertExpectedProblemsReported(self)
+
+  def testMerge_RequiredButNotDisjoint(self):
+    self._AddTwoPeriods('20070101', '20090101',
+                        '20080101', '20100101')
+    self.accumulator.ExpectProblemClass(merge.CalendarsNotDisjoint)
+    self.assertEquals(self.spm.MergeDataSets(), False)
+    self.accumulator.assertExpectedProblemsReported(self)
+
+  def testMerge_NotRequiredAndNotDisjoint(self):
+    self._AddTwoPeriods('20070101', '20090101',
+                        '20080101', '20100101')
+    self.spm.require_disjoint_calendars = False
+    self.accumulator.ExpectProblemClass(merge.MergeNotImplemented)
+    self.fm.MergeSchedules()
+    self.accumulator.assertExpectedProblemsReported(self)
+
+
+class TestAgencyMerger(util.TestCase):
+
+  def setUp(self):
+    a_schedule = transitfeed.Schedule()
+    b_schedule = transitfeed.Schedule()
+    merged_schedule = transitfeed.Schedule()
+    self.accumulator = TestingProblemAccumulator()
+    self.problem_reporter = TestingProblemReporter(self.accumulator)
+    self.fm = merge.FeedMerger(a_schedule, b_schedule, merged_schedule,
+                               self.problem_reporter)
+    self.am = merge.AgencyMerger(self.fm)
+    self.fm.AddMerger(self.am)
+
+    self.a1 = transitfeed.Agency(id='a1', agency_name='a1',
+                                 agency_url='http://www.a1.com',
+                                 agency_timezone='Africa/Johannesburg',
+                                 agency_phone='123 456 78 90')
+    self.a2 = transitfeed.Agency(id='a2', agency_name='a1',
+                                 agency_url='http://www.a1.com',
+                                 agency_timezone='Africa/Johannesburg',
+                                 agency_phone='789 65 43 21')
+
+  def testMerge(self):
+    self.a2.agency_id = self.a1.agency_id
+    self.fm.a_schedule.AddAgencyObject(self.a1)
+    self.fm.b_schedule.AddAgencyObject(self.a2)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetAgencyList()), 1)
+    self.assertEquals(merged_schedule.GetAgencyList()[0],
+                      self.fm.a_merge_map[self.a1])
+    self.assertEquals(self.fm.a_merge_map[self.a1],
+                      self.fm.b_merge_map[self.a2])
+    # differing values such as agency_phone should be taken from self.a2
+    self.assertEquals(merged_schedule.GetAgencyList()[0], self.a2)
+    self.assertEquals(self.am.GetMergeStats(), (1, 0, 0))
+
+    # check that id is preserved
+    self.assertEquals(self.fm.a_merge_map[self.a1].agency_id,
+                      self.a1.agency_id)
+
+  def testNoMerge_DifferentId(self):
+    self.fm.a_schedule.AddAgencyObject(self.a1)
+    self.fm.b_schedule.AddAgencyObject(self.a2)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetAgencyList()), 2)
+
+    self.assert_(self.fm.a_merge_map[self.a1] in
+                 merged_schedule.GetAgencyList())
+    self.assert_(self.fm.b_merge_map[self.a2] in
+                 merged_schedule.GetAgencyList())
+    self.assertEquals(self.a1, self.fm.a_merge_map[self.a1])
+    self.assertEquals(self.a2, self.fm.b_merge_map[self.a2])
+    self.assertEquals(self.am.GetMergeStats(), (0, 1, 1))
+
+    # check that the ids are preserved
+    self.assertEquals(self.fm.a_merge_map[self.a1].agency_id,
+                      self.a1.agency_id)
+    self.assertEquals(self.fm.b_merge_map[self.a2].agency_id,
+                      self.a2.agency_id)
+
+  def testNoMerge_SameId(self):
+    # Force a1.agency_id to be unicode to make sure it is correctly encoded
+    # to utf-8 before concatinating to the agency_name containing non-ascii
+    # characters.
+    self.a1.agency_id = unicode(self.a1.agency_id)
+    self.a2.agency_id = str(self.a1.agency_id)
+    self.a2.agency_name = 'different \xc3\xa9'
+    self.fm.a_schedule.AddAgencyObject(self.a1)
+    self.fm.b_schedule.AddAgencyObject(self.a2)
+
+    self.accumulator.ExpectProblemClass(merge.SameIdButNotMerged)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetAgencyList()), 2)
+    self.assertEquals(self.am.GetMergeStats(), (0, 1, 1))
+
+    # check that the merged entities have different ids
+    self.assertNotEqual(self.fm.a_merge_map[self.a1].agency_id,
+                        self.fm.b_merge_map[self.a2].agency_id)
+
+    self.accumulator.assertExpectedProblemsReported(self)
+
+
+class TestStopMerger(util.TestCase):
+
+  def setUp(self):
+    a_schedule = transitfeed.Schedule()
+    b_schedule = transitfeed.Schedule()
+    merged_schedule = transitfeed.Schedule()
+    self.accumulator = TestingProblemAccumulator()
+    self.problem_reporter = TestingProblemReporter(self.accumulator)
+    self.fm = merge.FeedMerger(a_schedule, b_schedule, merged_schedule,
+                               self.problem_reporter)
+    self.sm = merge.StopMerger(self.fm)
+    self.fm.AddMerger(self.sm)
+
+    self.s1 = transitfeed.Stop(30.0, 30.0,
+                               u'Andr\202' , 's1')
+    self.s1.stop_desc = 'stop 1'
+    self.s1.stop_url = 'http://stop/1'
+    self.s1.zone_id = 'zone1'
+    self.s2 = transitfeed.Stop(30.0, 30.0, 's2', 's2')
+    self.s2.stop_desc = 'stop 2'
+    self.s2.stop_url = 'http://stop/2'
+    self.s2.zone_id = 'zone1'
+
+  def testMerge(self):
+    self.s2.stop_id = self.s1.stop_id
+    self.s2.stop_name = self.s1.stop_name
+    self.s1.location_type = 1
+    self.s2.location_type = 1
+
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetStopList()), 1)
+    self.assertEquals(merged_schedule.GetStopList()[0],
+                      self.fm.a_merge_map[self.s1])
+    self.assertEquals(self.fm.a_merge_map[self.s1],
+                      self.fm.b_merge_map[self.s2])
+    self.assertEquals(self.sm.GetMergeStats(), (1, 0, 0))
+
+    # check that the remaining attributes are taken from the new stop
+    fields = ['stop_name', 'stop_lat', 'stop_lon', 'stop_desc', 'stop_url',
+              'location_type']
+    CheckAttribs(self.fm.a_merge_map[self.s1], self.s2, fields,
+                 self.assertEquals)
+
+    # check that the id is preserved
+    self.assertEquals(self.fm.a_merge_map[self.s1].stop_id, self.s1.stop_id)
+
+    # check that the zone_id is preserved
+    self.assertEquals(self.fm.a_merge_map[self.s1].zone_id, self.s1.zone_id)
+
+  def testNoMerge_DifferentId(self):
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetStopList()), 2)
+    self.assert_(self.fm.a_merge_map[self.s1] in merged_schedule.GetStopList())
+    self.assert_(self.fm.b_merge_map[self.s2] in merged_schedule.GetStopList())
+    self.assertEquals(self.sm.GetMergeStats(), (0, 1, 1))
+
+  def testNoMerge_DifferentName(self):
+    self.s2.stop_id = self.s1.stop_id
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.accumulator.ExpectProblemClass(merge.SameIdButNotMerged)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetStopList()), 2)
+    self.assert_(self.fm.a_merge_map[self.s1] in merged_schedule.GetStopList())
+    self.assert_(self.fm.b_merge_map[self.s2] in merged_schedule.GetStopList())
+    self.assertEquals(self.sm.GetMergeStats(), (0, 1, 1))
+
+  def testNoMerge_FarApart(self):
+    self.s2.stop_id = self.s1.stop_id
+    self.s2.stop_name = self.s1.stop_name
+    self.s2.stop_lat = 40.0
+    self.s2.stop_lon = 40.0
+
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.accumulator.ExpectProblemClass(merge.SameIdButNotMerged)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetStopList()), 2)
+    self.assert_(self.fm.a_merge_map[self.s1] in merged_schedule.GetStopList())
+    self.assert_(self.fm.b_merge_map[self.s2] in merged_schedule.GetStopList())
+    self.assertEquals(self.sm.GetMergeStats(), (0, 1, 1))
+
+    # check that the merged ids are different
+    self.assertNotEquals(self.fm.a_merge_map[self.s1].stop_id,
+                         self.fm.b_merge_map[self.s2].stop_id)
+
+    self.accumulator.assertExpectedProblemsReported(self)
+
+  def testMerge_CaseInsensitive(self):
+    self.s2.stop_id = self.s1.stop_id
+    self.s2.stop_name = self.s1.stop_name.upper()
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetStopList()), 1)
+    self.assertEquals(self.sm.GetMergeStats(), (1, 0, 0))
+
+  def testNoMerge_ZoneId(self):
+    self.s2.zone_id = 'zone2'
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+
+    merged_schedule = self.fm.GetMergedSchedule()
+    self.assertEquals(len(merged_schedule.GetStopList()), 2)
+
+    self.assert_(self.s1.zone_id in self.fm.a_zone_map)
+    self.assert_(self.s2.zone_id in self.fm.b_zone_map)
+    self.assertEquals(self.sm.GetMergeStats(), (0, 1, 1))
+
+    # check that the zones are still different
+    self.assertNotEqual(self.fm.a_merge_map[self.s1].zone_id,
+                        self.fm.b_merge_map[self.s2].zone_id)
+
+  def testZoneId_SamePreservation(self):
+    # checks that if the zone_ids of some stops are the same before the
+    # merge, they are still the same after.
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.a_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+    self.assertEquals(self.fm.a_merge_map[self.s1].zone_id,
+                      self.fm.a_merge_map[self.s2].zone_id)
+
+  def testZoneId_DifferentSchedules(self):
+    # zone_ids may be the same in different schedules but unless the stops
+    # are merged, they should map to different zone_ids
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+    self.assertNotEquals(self.fm.a_merge_map[self.s1].zone_id,
+                         self.fm.b_merge_map[self.s2].zone_id)
+
+  def testZoneId_MergePreservation(self):
+    # check that if two stops are merged, the zone mapping is used for all
+    # other stops too
+    self.s2.stop_id = self.s1.stop_id
+    self.s2.stop_name = self.s1.stop_name
+    s3 = transitfeed.Stop(field_dict=self.s1)
+    s3.stop_id = 'different'
+
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.a_schedule.AddStopObject(s3)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+
+    self.assertEquals(self.fm.a_merge_map[self.s1].zone_id,
+                      self.fm.a_merge_map[s3].zone_id)
+    self.assertEquals(self.fm.a_merge_map[s3].zone_id,
+                      self.fm.b_merge_map[self.s2].zone_id)
+
+  def testMergeStationType(self):
+    self.s2.stop_id = self.s1.stop_id
+    self.s2.stop_name = self.s1.stop_name
+    self.s1.location_type = 1
+    self.s2.location_type = 1
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+    merged_stops = self.fm.GetMergedSchedule().GetStopList()
+    self.assertEquals(len(merged_stops), 1)
+    self.assertEquals(merged_stops[0].location_type, 1)
+
+  def testMergeDifferentTypes(self):
+    self.s2.stop_id = self.s1.stop_id
+    self.s2.stop_name = self.s1.stop_name
+    self.s2.location_type = 1
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    try:
+      self.fm.MergeSchedules()
+      self.fail("Expecting MergeError")
+    except merge.SameIdButNotMerged, merge_error:
+      self.assertTrue(("%s" % merge_error).find("location_type") != -1)
+
+  def AssertS1ParentIsS2(self):
+    """Assert that the merged s1 has parent s2."""
+    new_s1 = self.s1._migrated_entity
+    new_s2 = self.s2._migrated_entity
+    self.assertEquals(new_s1.parent_station, new_s2.stop_id)
+    self.assertEquals(new_s2.parent_station, None)
+    self.assertEquals(new_s1.location_type, 0)
+    self.assertEquals(new_s2.location_type, 1)
+
+  def testMergeMaintainParentRelationship(self):
+    self.s2.location_type = 1
+    self.s1.parent_station = self.s2.stop_id
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.a_schedule.AddStopObject(self.s2)
+    self.fm.MergeSchedules()
+    self.AssertS1ParentIsS2()
+
+  def testParentRelationshipAfterMerge(self):
+    s3 = transitfeed.Stop(field_dict=self.s1)
+    s3.parent_station = self.s2.stop_id
+    self.s2.location_type = 1
+    self.fm.a_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.fm.b_schedule.AddStopObject(s3)
+    self.fm.MergeSchedules()
+    self.AssertS1ParentIsS2()
+
+  def testParentRelationshipWithNewParentid(self):
+    self.s2.location_type = 1
+    self.s1.parent_station = self.s2.stop_id
+    # s3 will have a stop_id conflict with self.s2 so parent_id of the
+    # migrated self.s1 will need to be updated
+    s3 = transitfeed.Stop(field_dict=self.s2)
+    s3.stop_lat = 45
+    self.fm.a_schedule.AddStopObject(s3)
+    self.fm.b_schedule.AddStopObject(self.s1)
+    self.fm.b_schedule.AddStopObject(self.s2)
+    self.accumulator.ExpectProblemClass(merge.SameIdButNotMerged)
+    self.fm.MergeSchedules()
+    self.assertNotEquals(s3._migrated_entity.stop_id,
+                         self.s2._migrated_entity.stop_id)
+    # Check that s2 got a new id
+    self.assertNotEquals(self.s2.stop_id,
+                         self.s2._migrated_entity.stop_id)
+    self.AssertS1ParentIsS2()
+
+  def _AddStopsApart(self):
+    """Adds two stops to the schedules and returns the distance between them.
+
+    Returns:
+      The distance between the stops in metres, a value greater than zero.
+    """
+    self.s2.stop_id = self.s1.stop_id
+    self.s2.stop_name = self.s1.stop_name
+    self.s2.stop_lat += 1.0e-3
+    self.fm.a_sc