|
#!/usr/bin/python2.5 |
|
|
|
# Copyright (C) 2007 Google Inc. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
# Unit tests for the transitfeed module. |
|
|
|
import datetime |
|
from datetime import date |
|
import dircache |
|
import os.path |
|
import re |
|
import sys |
|
import tempfile |
|
import time |
|
import transitfeed |
|
import types |
|
import unittest |
|
import util |
|
from util import RecordingProblemAccumulator |
|
from StringIO import StringIO |
|
import zipfile |
|
import zlib |
|
|
|
|
|
def DataPath(path):
  """Return `path` joined onto the 'data' directory next to this file.

  Args:
    path: name of a file or directory, relative to the 'data' directory.

  Returns:
    os.path.join() of this file's directory, 'data' and `path`.
  """
  return os.path.join(os.path.dirname(__file__), 'data', path)
|
|
|
def GetDataPathContents():
  """Return the sorted names of the entries in the 'data' directory.

  The directory listed is the 'data' folder next to this file. os.listdir is
  used instead of dircache.listdir because the dircache module is deprecated
  since Python 2.6 and no longer exists in Python 3. The listing is sorted
  explicitly: dircache returned a sorted list, while os.listdir promises no
  particular order.

  Returns:
    A new list of entry names (not full paths) in sorted order.

  Raises:
    OSError: if the 'data' directory cannot be listed.
  """
  here = os.path.dirname(__file__)
  return sorted(os.listdir(os.path.join(here, 'data')))
|
|
|
|
|
class ExceptionProblemReporterNoExpiration(transitfeed.ProblemReporter):
  """Problem reporter that does nothing for feed expiration problems.

  It reports through a transitfeed.ExceptionProblemAccumulator created with
  raise_warnings=True.

  New code should prefer GetTestFailureProblemReporter(): it fails more
  cleanly, is easier to extend and does more thorough checking.
  """

  def __init__(self):
    transitfeed.ProblemReporter.__init__(
        self, transitfeed.ExceptionProblemAccumulator(raise_warnings=True))

  def ExpirationDate(self, expiration, context=None):
    """Do nothing: we don't want to give errors about our test data."""
|
|
|
|
|
def GetTestFailureProblemReporter(test_case,
                                  ignore_types=("ExpirationDate",)):
  """Build a ProblemReporter backed by a TestFailureProblemAccumulator.

  Args:
    test_case: the test case handed to the accumulator.
    ignore_types: problem class names the accumulator should not fail on.

  Returns:
    A transitfeed.ProblemReporter wrapping the new accumulator.
  """
  return transitfeed.ProblemReporter(
      TestFailureProblemAccumulator(test_case, ignore_types))
|
|
|
|
|
class TestFailureProblemAccumulator(transitfeed.ProblemAccumulatorInterface):
  """Accumulator that fails the test as soon as any problem is reported."""

  def __init__(self, test_case, ignore_types=("ExpirationDate",)):
    """Remember the test case and the problem types to let through.

    Args:
      test_case: object whose fail() method is called for each problem.
      ignore_types: collection of problem class names that are not failures;
        a false value means that nothing is ignored.
    """
    self.test_case = test_case
    if ignore_types:
      self._ignore_types = ignore_types
    else:
      self._ignore_types = set()

  def _Report(self, e):
    """Fail the test case for problem `e` unless its type is ignored."""
    # Format before looking at the ignore list: these calls should never
    # crash, so they are exercised for ignored problems as well.
    problem_text = e.FormatProblem()
    context_text = e.FormatContext()
    type_name = e.__class__.__name__
    if type_name not in self._ignore_types:
      self.test_case.fail(
          "%s: %s\n%s" % (type_name, problem_text, context_text))
|
|
|
|
|
class UnrecognizedColumnRecorder(transitfeed.ProblemReporter):
  """Keeps track of unrecognized column errors.

  Every UnrecognizedColumn report is appended to `column_errors` as a
  (file_name, column_name) tuple.
  """

  def __init__(self, test_case):
    """Set up the recorder.

    Args:
      test_case: the running test case, handed to the
        RecordingProblemAccumulator that backs this reporter.
    """
    accumulator = RecordingProblemAccumulator(
        test_case, ignore_types=("ExpirationDate",))
    # Hand the accumulator to the base constructor, exactly like
    # ExceptionProblemReporterNoExpiration does, instead of assigning the
    # attribute directly: that way whatever state ProblemReporter.__init__
    # prepares is in place before the loader starts reporting.
    transitfeed.ProblemReporter.__init__(self, accumulator)
    self.column_errors = []

  def UnrecognizedColumn(self, file_name, column_name, context=None):
    """Record the offending (file_name, column_name) pair."""
    self.column_errors.append((file_name, column_name))
|
|
|
|
|
class RedirectStdOutTestCaseBase(util.TestCase):
  """Test base that captures stdout in the StringIO buffer self.this_stdout."""

  def setUp(self):
    # Keep the real stream around so that tearDown can put it back.
    self.saved_stdout, self.this_stdout = sys.stdout, StringIO()
    sys.stdout = self.this_stdout

  def tearDown(self):
    # Restore stdout first, then release the capture buffer.
    sys.stdout = self.saved_stdout
    self.this_stdout.close()
|
|
|
|
|
# Loading and validating every entry of the data directory must not raise
# an exception (so that the validator won't crash).
class NoExceptionTestCase(RedirectStdOutTestCaseBase):
  def runTest(self):
    for feed_name in GetDataPathContents():
      feed_loader = transitfeed.Loader(
          DataPath(feed_name),
          problems=transitfeed.ProblemReporter(),
          extra_validation=True)
      feed_loader.Load().Validate()
|
|
|
|
|
class EndOfLineCheckerTestCase(util.TestCase):
  """Checks the problems reported by transitfeed.EndOfLineChecker."""

  def setUp(self):
    self.accumulator = RecordingProblemAccumulator(self)
    self.problems = transitfeed.ProblemReporter(self.accumulator)

  def RunEndOfLineChecker(self, end_of_line_checker):
    """Consume every line of end_of_line_checker, discarding the lines."""
    # A for loop keeps calling end_of_line_checker.next() until StopIteration
    # is raised. EndOfLineChecker does its final check for a mix of CR LF and
    # LF line ends just before raising StopIteration, so the checker has to
    # be drained completely.
    for _ in end_of_line_checker:
      pass

  def _CheckText(self, text):
    """Run an EndOfLineChecker named "<StringIO>" over `text`."""
    checker = transitfeed.EndOfLineChecker(
        StringIO(text), "<StringIO>", self.problems)
    self.RunEndOfLineChecker(checker)

  def testInvalidLineEnd(self):
    self._CheckText("line1\r\r\nline2")
    problem = self.accumulator.PopException("InvalidLineEnd")
    self.assertEqual(problem.file_name, "<StringIO>")
    self.assertEqual(problem.row_num, 1)
    self.assertEqual(problem.bad_line_end, r"\r\r\n")
    self.accumulator.AssertNoMoreExceptions()

  def testInvalidLineEndToo(self):
    self._CheckText("line1\nline2\r\nline3\r\r\r\n")
    bad_end = self.accumulator.PopException("InvalidLineEnd")
    self.assertEqual(bad_end.file_name, "<StringIO>")
    self.assertEqual(bad_end.row_num, 3)
    self.assertEqual(bad_end.bad_line_end, r"\r\r\r\n")
    mixed = self.accumulator.PopException("OtherProblem")
    self.assertEqual(mixed.file_name, "<StringIO>")
    self.assertTrue("consistent line end" in mixed.description)
    self.accumulator.AssertNoMoreExceptions()

  def testEmbeddedCr(self):
    self._CheckText("line1\rline1b")
    problem = self.accumulator.PopException("OtherProblem")
    self.assertEqual(problem.file_name, "<StringIO>")
    self.assertEqual(problem.row_num, 1)
    self.assertEqual(problem.FormatProblem(),
                     "Line contains ASCII Carriage Return 0x0D, \\r")
    self.accumulator.AssertNoMoreExceptions()

  def testEmbeddedUtf8NextLine(self):
    self._CheckText("line1b\xc2\x85")
    problem = self.accumulator.PopException("OtherProblem")
    self.assertEqual(problem.file_name, "<StringIO>")
    self.assertEqual(problem.row_num, 1)
    self.assertEqual(problem.FormatProblem(),
                     "Line contains Unicode NEXT LINE SEPARATOR U+0085")
    self.accumulator.AssertNoMoreExceptions()

  def testEndOfLineMix(self):
    self._CheckText("line1\nline2\r\nline3\nline4")
    problem = self.accumulator.PopException("OtherProblem")
    self.assertEqual(problem.file_name, "<StringIO>")
    expected_text = ("Found 1 CR LF \"\\r\\n\" line end (line 2) and "
                     "2 LF \"\\n\" line ends (lines 1, 3). A file must use a "
                     "consistent line end.")
    self.assertEqual(problem.FormatProblem(), expected_text)
    self.accumulator.AssertNoMoreExceptions()

  def testEndOfLineManyMix(self):
    self._CheckText("1\n2\n3\n4\n5\n6\n7\r\n8\r\n9\r\n10\r\n11\r\n")
    problem = self.accumulator.PopException("OtherProblem")
    self.assertEqual(problem.file_name, "<StringIO>")
    expected_text = ("Found 5 CR LF \"\\r\\n\" line ends (lines 7, 8, 9, 10, "
                     "11) and 6 LF \"\\n\" line ends (lines 1, 2, 3, 4, 5, "
                     "...). A file must use a consistent line end.")
    self.assertEqual(problem.FormatProblem(), expected_text)
    self.accumulator.AssertNoMoreExceptions()

  def testLoad(self):
    transitfeed.Loader(
        DataPath("bad_eol.zip"),
        problems=self.problems,
        extra_validation=True).Load()

    mixed = self.accumulator.PopException("OtherProblem")
    self.assertEqual(mixed.file_name, "calendar.txt")
    self.assertTrue(re.search(
        r"Found 1 CR LF.* \(line 2\) and 2 LF .*\(lines 1, 3\)",
        mixed.FormatProblem()))

    bad_end = self.accumulator.PopException("InvalidLineEnd")
    self.assertEqual(bad_end.file_name, "routes.txt")
    self.assertEqual(bad_end.row_num, 5)
    self.assertTrue(r"\r\r\n" in bad_end.FormatProblem())

    form_feed = self.accumulator.PopException("OtherProblem")
    self.assertEqual(form_feed.file_name, "trips.txt")
    self.assertEqual(form_feed.row_num, 1)
    self.assertTrue(re.search(r"contains ASCII Form Feed",
                              form_feed.FormatProblem()))
    # TODO(Tom): avoid this duplicate error for the same issue
    csv_syntax = self.accumulator.PopException("CsvSyntax")
    self.assertEqual(csv_syntax.row_num, 1)
    self.assertTrue(re.search(
        r"header row should not contain any space char",
        csv_syntax.FormatProblem()))

    self.accumulator.AssertNoMoreExceptions()
|
|
|
|
|
class LoadTestCase(util.TestCase):
  """Base class for tests that load a feed from the data directory.

  Problems reported while loading go to self.accumulator, a
  RecordingProblemAccumulator set up to ignore "ExpirationDate".
  """

  def setUp(self):
    self.accumulator = RecordingProblemAccumulator(self, ("ExpirationDate",))
    self.problems = transitfeed.ProblemReporter(self.accumulator)

  def Load(self, feed_name):
    """Load data feed `feed_name` with extra validation switched on."""
    feed_loader = transitfeed.Loader(DataPath(feed_name),
                                     problems=self.problems,
                                     extra_validation=True)
    feed_loader.Load()

  def ExpectInvalidValue(self, feed_name, column_name):
    """Load the feed and expect only an invalid value in `column_name`."""
    self.Load(feed_name)
    self.accumulator.PopInvalidValue(column_name)
    self.accumulator.AssertNoMoreExceptions()

  def ExpectMissingFile(self, feed_name, file_name):
    """Load the feed and expect a MissingFile problem about `file_name`."""
    self.Load(feed_name)
    missing = self.accumulator.PopException("MissingFile")
    # AssertNoMoreExceptions() is deliberately not called here because a
    # missing file causes many errors.
    self.assertEqual(file_name, missing.file_name)
|
|
|
|
|
class LoadFromZipTestCase(util.TestCase):
  """Loads good_feed.zip through both Loader and Schedule.Load."""

  def runTest(self):
    zip_path = DataPath('good_feed.zip')
    feed_loader = transitfeed.Loader(
        zip_path,
        problems=GetTestFailureProblemReporter(self),
        extra_validation=True)
    feed_loader.Load()

    # now try using Schedule.Load
    reporter = ExceptionProblemReporterNoExpiration()
    schedule = transitfeed.Schedule(problem_reporter=reporter)
    schedule.Load(zip_path, extra_validation=True)
|
|
|
|
|
class LoadAndRewriteFromZipTestCase(util.TestCase):
  """Loads good_feed.zip and checks that writing it back out does not crash."""

  def runTest(self):
    schedule = transitfeed.Schedule(
        problem_reporter=ExceptionProblemReporterNoExpiration())
    schedule.Load(DataPath('good_feed.zip'), extra_validation=True)

    # Finally see if write crashes. The temporary file is closed explicitly
    # so that its handle is released even when the write raises, rather than
    # lingering until the object happens to be garbage collected.
    output = tempfile.TemporaryFile()
    try:
      schedule.WriteGoogleTransitFeed(output)
    finally:
      output.close()
|
|
|
|
|
class LoadFromDirectoryTestCase(util.TestCase):
  """Loads the 'good_feed' directory; any reported problem fails the test."""

  def runTest(self):
    reporter = GetTestFailureProblemReporter(self)
    feed_loader = transitfeed.Loader(DataPath('good_feed'),
                                     problems=reporter,
                                     extra_validation=True)
    feed_loader.Load()
|
|
|
|
|
class LoadUnknownFeedTestCase(util.TestCase):
  """Expects FeedNotFound when loading the 'unknown_feed' data path."""

  def runTest(self):
    feed_name = DataPath('unknown_feed')
    reporter = ExceptionProblemReporterNoExpiration()
    feed_loader = transitfeed.Loader(feed_name,
                                     problems=reporter,
                                     extra_validation=True)
    try:
      feed_loader.Load()
    except transitfeed.FeedNotFound:
      # The exception is fetched through sys.exc_info() so that this clause
      # does not depend on version-specific "except ..., e" syntax.
      error = sys.exc_info()[1]
      self.assertEqual(feed_name, error.feed_name)
    else:
      self.fail('FeedNotFound exception expected')
|
|
|
class LoadUnknownFormatTestCase(util.TestCase):
  """Expects UnknownFormat when loading 'unknown_format.zip'."""

  def runTest(self):
    feed_name = DataPath('unknown_format.zip')
    reporter = ExceptionProblemReporterNoExpiration()
    feed_loader = transitfeed.Loader(feed_name,
                                     problems=reporter,
                                     extra_validation=True)
    try:
      feed_loader.Load()
    except transitfeed.UnknownFormat:
      # The exception is fetched through sys.exc_info() so that this clause
      # does not depend on version-specific "except ..., e" syntax.
      error = sys.exc_info()[1]
      self.assertEqual(feed_name, error.feed_name)
    else:
      self.fail('UnknownFormat exception expected')
|
|
|
class LoadUnrecognizedColumnsTestCase(util.TestCase):
  """Checks the columns reported for the 'unrecognized_columns' feed."""

  def runTest(self):
    recorder = UnrecognizedColumnRecorder(self)
    feed_loader = transitfeed.Loader(DataPath('unrecognized_columns'),
                                     problems=recorder)
    feed_loader.Load()

    expected = set([
        ('agency.txt', 'agency_lange'),
        ('stops.txt', 'stop_uri'),
        ('routes.txt', 'Route_Text_Color'),
        ('calendar.txt', 'leap_day'),
        ('calendar_dates.txt', 'leap_day'),
        ('trips.txt', 'sharpe_id'),
        ('stop_times.txt', 'shapedisttraveled'),
        ('stop_times.txt', 'drop_off_time'),
        ('fare_attributes.txt', 'transfer_time'),
        ('fare_rules.txt', 'source_id'),
        ('frequencies.txt', 'superfluous'),
        ('transfers.txt', 'to_stop'),
    ])
    found = set(recorder.column_errors)

    # Now make sure we got the unrecognized column errors that we expected:
    # nothing extra was reported and nothing expected is missing.
    surplus = found - expected
    self.failIf(surplus, 'unexpected errors: %s' % str(surplus))
    missing = expected - found
    self.failIf(missing, 'expected but not found: %s' % str(missing))
|
|
|
class LoadExtraCellValidationTestCase(LoadTestCase):
  """Check that the validation detects too many cells in a row."""

  def runTest(self):
    self.Load('extra_row_cells')
    problem = self.accumulator.PopException("OtherProblem")
    # Compare file name and row number of the single expected problem.
    self.assertEqual(("routes.txt", 4)[0], problem.file_name)
    self.assertEqual(("routes.txt", 4)[1], problem.row_num)
    self.accumulator.AssertNoMoreExceptions()
|
|
|
|
|
class LoadMissingCellValidationTestCase(LoadTestCase): |
|
"""Check that the validation detects missing cells in a row.""" |
|
def runTest(self): |
|
self.Load('missing_row_cells') |
|
e = self.accumulator.PopException("OtherProblem") |
|