Upgrade origin-src to google transit feed 1.2.6
[bus.git] / origin-src / transitfeed-1.2.6 / test / testtransitfeed.py
blob:a/origin-src/transitfeed-1.2.6/test/testtransitfeed.py -> blob:b/origin-src/transitfeed-1.2.6/test/testtransitfeed.py
  #!/usr/bin/python2.5
   
  # Copyright (C) 2007 Google Inc.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
   
  # Unit tests for the transitfeed module.
   
  import datetime
  from datetime import date
  import dircache
  import os.path
  import re
  import sys
  import tempfile
  import time
  import transitfeed
  import types
  import unittest
  import util
  from util import RecordingProblemAccumulator
  from StringIO import StringIO
  import zipfile
  import zlib
   
   
  def DataPath(path):
  here = os.path.dirname(__file__)
  return os.path.join(here, 'data', path)
   
  def GetDataPathContents():
  here = os.path.dirname(__file__)
  return dircache.listdir(os.path.join(here, 'data'))
   
   
  class ExceptionProblemReporterNoExpiration(transitfeed.ProblemReporter):
  """Ignores feed expiration problems.
   
  Use TestFailureProblemReporter in new code because it fails more cleanly, is
  easier to extend and does more thorough checking.
  """
   
  def __init__(self):
  accumulator = transitfeed.ExceptionProblemAccumulator(raise_warnings=True)
  transitfeed.ProblemReporter.__init__(self, accumulator)
   
  def ExpirationDate(self, expiration, context=None):
  pass # We don't want to give errors about our test data
   
   
  def GetTestFailureProblemReporter(test_case,
  ignore_types=("ExpirationDate",)):
  accumulator = TestFailureProblemAccumulator(test_case, ignore_types)
  problems = transitfeed.ProblemReporter(accumulator)
  return problems
   
   
  class TestFailureProblemAccumulator(transitfeed.ProblemAccumulatorInterface):
  """Causes a test failure immediately on any problem."""
  def __init__(self, test_case, ignore_types=("ExpirationDate",)):
  self.test_case = test_case
  self._ignore_types = ignore_types or set()
   
  def _Report(self, e):
  # These should never crash
  formatted_problem = e.FormatProblem()
  formatted_context = e.FormatContext()
  exception_class = e.__class__.__name__
  if exception_class in self._ignore_types:
  return
  self.test_case.fail(
  "%s: %s\n%s" % (exception_class, formatted_problem, formatted_context))
   
   
  class UnrecognizedColumnRecorder(transitfeed.ProblemReporter):
  """Keeps track of unrecognized column errors."""
  def __init__(self, test_case):
  self.accumulator = RecordingProblemAccumulator(test_case,
  ignore_types=("ExpirationDate",))
  self.column_errors = []
   
  def UnrecognizedColumn(self, file_name, column_name, context=None):
  self.column_errors.append((file_name, column_name))
   
   
  class RedirectStdOutTestCaseBase(util.TestCase):
  """Save stdout to the StringIO buffer self.this_stdout"""
  def setUp(self):
  self.saved_stdout = sys.stdout
  self.this_stdout = StringIO()
  sys.stdout = self.this_stdout
   
  def tearDown(self):
  sys.stdout = self.saved_stdout
  self.this_stdout.close()
   
   
  # ensure that there are no exceptions when attempting to load
  # (so that the validator won't crash)
  class NoExceptionTestCase(RedirectStdOutTestCaseBase):
  def runTest(self):
  for feed in GetDataPathContents():
  loader = transitfeed.Loader(DataPath(feed),
  problems=transitfeed.ProblemReporter(),
  extra_validation=True)
  schedule = loader.Load()
  schedule.Validate()
   
   
  class EndOfLineCheckerTestCase(util.TestCase):
  def setUp(self):
  self.accumulator = RecordingProblemAccumulator(self)
  self.problems = transitfeed.ProblemReporter(self.accumulator)
   
  def RunEndOfLineChecker(self, end_of_line_checker):
  # Iterating using for calls end_of_line_checker.next() until a
  # StopIteration is raised. EndOfLineChecker does the final check for a mix
  # of CR LF and LF ends just before raising StopIteration.
  for line in end_of_line_checker:
  pass
   
  def testInvalidLineEnd(self):
  f = transitfeed.EndOfLineChecker(StringIO("line1\r\r\nline2"),
  "<StringIO>",
  self.problems)
  self.RunEndOfLineChecker(f)
  e = self.accumulator.PopException("InvalidLineEnd")
  self.assertEqual(e.file_name, "<StringIO>")
  self.assertEqual(e.row_num, 1)
  self.assertEqual(e.bad_line_end, r"\r\r\n")
  self.accumulator.AssertNoMoreExceptions()
   
  def testInvalidLineEndToo(self):
  f = transitfeed.EndOfLineChecker(
  StringIO("line1\nline2\r\nline3\r\r\r\n"),
  "<StringIO>", self.problems)
  self.RunEndOfLineChecker(f)
  e = self.accumulator.PopException("InvalidLineEnd")
  self.assertEqual(e.file_name, "<StringIO>")
  self.assertEqual(e.row_num, 3)
  self.assertEqual(e.bad_line_end, r"\r\r\r\n")
  e = self.accumulator.PopException("OtherProblem")
  self.assertEqual(e.file_name, "<StringIO>")
  self.assertTrue(e.description.find("consistent line end") != -1)
  self.accumulator.AssertNoMoreExceptions()
   
  def testEmbeddedCr(self):
  f = transitfeed.EndOfLineChecker(
  StringIO("line1\rline1b"),
  "<StringIO>", self.problems)
  self.RunEndOfLineChecker(f)
  e = self.accumulator.PopException("OtherProblem")
  self.assertEqual(e.file_name, "<StringIO>")
  self.assertEqual(e.row_num, 1)
  self.assertEqual(e.FormatProblem(),
  "Line contains ASCII Carriage Return 0x0D, \\r")
  self.accumulator.AssertNoMoreExceptions()
   
  def testEmbeddedUtf8NextLine(self):
  f = transitfeed.EndOfLineChecker(
  StringIO("line1b\xc2\x85"),
  "<StringIO>", self.problems)
  self.RunEndOfLineChecker(f)
  e = self.accumulator.PopException("OtherProblem")
  self.assertEqual(e.file_name, "<StringIO>")
  self.assertEqual(e.row_num, 1)
  self.assertEqual(e.FormatProblem(),
  "Line contains Unicode NEXT LINE SEPARATOR U+0085")
  self.accumulator.AssertNoMoreExceptions()
   
  def testEndOfLineMix(self):
  f = transitfeed.EndOfLineChecker(
  StringIO("line1\nline2\r\nline3\nline4"),
  "<StringIO>", self.problems)
  self.RunEndOfLineChecker(f)
  e = self.accumulator.PopException("OtherProblem")
  self.assertEqual(e.file_name, "<StringIO>")
  self.assertEqual(e.FormatProblem(),
  "Found 1 CR LF \"\\r\\n\" line end (line 2) and "
  "2 LF \"\\n\" line ends (lines 1, 3). A file must use a "
  "consistent line end.")
  self.accumulator.AssertNoMoreExceptions()
   
  def testEndOfLineManyMix(self):
  f = transitfeed.EndOfLineChecker(
  StringIO("1\n2\n3\n4\n5\n6\n7\r\n8\r\n9\r\n10\r\n11\r\n"),
  "<StringIO>", self.problems)
  self.RunEndOfLineChecker(f)
  e = self.accumulator.PopException("OtherProblem")
  self.assertEqual(e.file_name, "<StringIO>")
  self.assertEqual(e.FormatProblem(),
  "Found 5 CR LF \"\\r\\n\" line ends (lines 7, 8, 9, 10, "
  "11) and 6 LF \"\\n\" line ends (lines 1, 2, 3, 4, 5, "
  "...). A file must use a consistent line end.")
  self.accumulator.AssertNoMoreExceptions()
   
  def testLoad(self):
  loader = transitfeed.Loader(
  DataPath("bad_eol.zip"), problems=self.problems, extra_validation=True)
  loader.Load()
   
  e = self.accumulator.PopException("OtherProblem")
  self.assertEqual(e.file_name, "calendar.txt")
  self.assertTrue(re.search(
  r"Found 1 CR LF.* \(line 2\) and 2 LF .*\(lines 1, 3\)",
  e.FormatProblem()))
   
  e = self.accumulator.PopException("InvalidLineEnd")
  self.assertEqual(e.file_name, "routes.txt")
  self.assertEqual(e.row_num, 5)
  self.assertTrue(e.FormatProblem().find(r"\r\r\n") != -1)
   
  e = self.accumulator.PopException("OtherProblem")
  self.assertEqual(e.file_name, "trips.txt")
  self.assertEqual(e.row_num, 1)
  self.assertTrue(re.search(
  r"contains ASCII Form Feed",
  e.FormatProblem()))
  # TODO(Tom): avoid this duplicate error for the same issue
  e = self.accumulator.PopException("CsvSyntax")
  self.assertEqual(e.row_num, 1)
  self.assertTrue(re.search(
  r"header row should not contain any space char",
  e.FormatProblem()))
   
  self.accumulator.AssertNoMoreExceptions()
   
   
  class LoadTestCase(util.TestCase):
  def setUp(self):
  self.accumulator = RecordingProblemAccumulator(self, ("ExpirationDate",))
  self.problems = transitfeed.ProblemReporter(self.accumulator)
   
  def Load(self, feed_name):
  loader = transitfeed.Loader(
  DataPath(feed_name), problems=self.problems, extra_validation=True)
  loader.Load()
   
  def ExpectInvalidValue(self, feed_name, column_name):
  self.Load(feed_name)
  self.accumulator.PopInvalidValue(column_name)
  self.accumulator.AssertNoMoreExceptions()
   
  def ExpectMissingFile(self, feed_name, file_name):
  self.Load(feed_name)
  e = self.accumulator.PopException("MissingFile")
  self.assertEqual(file_name, e.file_name)
  # Don't call AssertNoMoreExceptions() because a missing file causes
  # many errors.
   
   
  class LoadFromZipTestCase(util.TestCase):
  def runTest(self):
  loader = transitfeed.Loader(
  DataPath('good_feed.zip'),
  problems = GetTestFailureProblemReporter(self),
  extra_validation = True)
  loader.Load()
   
  # now try using Schedule.Load
  schedule = transitfeed.Schedule(
  problem_reporter=ExceptionProblemReporterNoExpiration())
  schedule.Load(DataPath('good_feed.zip'), extra_validation=True)
   
   
  class LoadAndRewriteFromZipTestCase(util.TestCase):
  def runTest(self):
  schedule = transitfeed.Schedule(
  problem_reporter=ExceptionProblemReporterNoExpiration())
  schedule.Load(DataPath('good_feed.zip'), extra_validation=True)
   
  # Finally see if write crashes
  schedule.WriteGoogleTransitFeed(tempfile.TemporaryFile())
   
   
  class LoadFromDirectoryTestCase(util.TestCase):
  def runTest(self):
  loader = transitfeed.Loader(
  DataPath('good_feed'),
  problems = GetTestFailureProblemReporter(self),
  extra_validation = True)
  loader.Load()
   
   
  class LoadUnknownFeedTestCase(util.TestCase):
  def runTest(self):
  feed_name = DataPath('unknown_feed')
  loader = transitfeed.Loader(
  feed_name,
  problems = ExceptionProblemReporterNoExpiration(),
  extra_validation = True)
  try:
  loader.Load()
  self.fail('FeedNotFound exception expected')
  except transitfeed.FeedNotFound, e:
  self.assertEqual(feed_name, e.feed_name)
   
  class LoadUnknownFormatTestCase(util.TestCase):
  def runTest(self):
  feed_name = DataPath('unknown_format.zip')
  loader = transitfeed.Loader(
  feed_name,
  problems = ExceptionProblemReporterNoExpiration(),
  extra_validation = True)
  try:
  loader.Load()
  self.fail('UnknownFormat exception expected')
  except transitfeed.UnknownFormat, e:
  self.assertEqual(feed_name, e.feed_name)
   
  class LoadUnrecognizedColumnsTestCase(util.TestCase):
  def runTest(self):
  problems = UnrecognizedColumnRecorder(self)
  loader = transitfeed.Loader(DataPath('unrecognized_columns'),
  problems=problems)
  loader.Load()
  found_errors = set(problems.column_errors)
  expected_errors = set([
  ('agency.txt', 'agency_lange'),
  ('stops.txt', 'stop_uri'),
  ('routes.txt', 'Route_Text_Color'),
  ('calendar.txt', 'leap_day'),
  ('calendar_dates.txt', 'leap_day'),
  ('trips.txt', 'sharpe_id'),
  ('stop_times.txt', 'shapedisttraveled'),
  ('stop_times.txt', 'drop_off_time'),
  ('fare_attributes.txt', 'transfer_time'),
  ('fare_rules.txt', 'source_id'),
  ('frequencies.txt', 'superfluous'),
  ('transfers.txt', 'to_stop')
  ])
   
  # Now make sure we got the unrecognized column errors that we expected.
  not_expected = found_errors.difference(expected_errors)
  self.failIf(not_expected, 'unexpected errors: %s' % str(not_expected))
  not_found = expected_errors.difference(found_errors)
  self.failIf(not_found, 'expected but not found: %s' % str(not_found))
   
  class LoadExtraCellValidationTestCase(LoadTestCase):
  """Check that the validation detects too many cells in a row."""
  def runTest(self):
  self.Load('extra_row_cells')
  e = self.accumulator.PopException("OtherProblem")
  self.assertEquals("routes.txt", e.file_name)
  self.assertEquals(4, e.row_num)
  self.accumulator.AssertNoMoreExceptions()
   
   
  class LoadMissingCellValidationTestCase(LoadTestCase):
  """Check that the validation detects missing cells in a row."""
  def runTest(self):
  self.Load('missing_row_cells')
  e = self.accumulator.PopException("OtherProblem")