Upgrade origin-src to google transit feed 1.2.6
--- a/origin-src/transitfeed-1.2.6/transitfeed/loader.py
+++ b/origin-src/transitfeed-1.2.6/transitfeed/loader.py
@@ -1,1 +1,567 @@
-
+#!/usr/bin/python2.5
+
+# Copyright (C) 2007 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import cStringIO as StringIO
+import csv
+import os
+import re
+import zipfile
+
+import gtfsfactory as gtfsfactory_module
+import problems
+import util
+
+class Loader:
+  def __init__(self,
+               feed_path=None,
+               schedule=None,
+               problems=problems.default_problem_reporter,
+               extra_validation=False,
+               load_stop_times=True,
+               memory_db=True,
+               zip=None,
+               check_duplicate_trips=False,
+               gtfs_factory=None):
+    """Initialize a new Loader object.
+
+    Args:
+      feed_path: string path to a zip file or directory
+      schedule: a Schedule object or None to have one created
+      problems: a ProblemReporter object, the default reporter raises an
+        exception for each problem
+      extra_validation: True if you would like extra validation
+      load_stop_times: load the stop_times table, used to speed load time when
+        times are not needed. The default is True.
+      memory_db: if creating a new Schedule object use an in-memory sqlite
+        database instead of creating one in a temporary file
+      zip: a zipfile.ZipFile object, optionally used instead of path
+      check_duplicate_trips: forwarded to the Schedule constructor when a new
+        Schedule object is created
+      gtfs_factory: a GtfsFactory object, or None to use the default factory
+    """
+    if gtfs_factory is None:
+      gtfs_factory = gtfsfactory_module.GetGtfsFactory()
+
+    if not schedule:
+      schedule = gtfs_factory.Schedule(problem_reporter=problems,
+          memory_db=memory_db, check_duplicate_trips=check_duplicate_trips)
+
+    self._extra_validation = extra_validation
+    self._schedule = schedule
+    self._problems = problems
+    self._path = feed_path
+    self._zip = zip
+    self._load_stop_times = load_stop_times
+    self._gtfs_factory = gtfs_factory
+
+  def _DetermineFormat(self):
+    """Determines whether the feed is in a form that we understand, and
+       if so, returns True."""
+    if self._zip:
+      # If zip was passed to __init__ then path isn't used
+      assert not self._path
+      return True
+
+    if not isinstance(self._path, basestring) and hasattr(self._path, 'read'):
+      # A file-like object, used for testing with a StringIO file
+      self._zip = zipfile.ZipFile(self._path, mode='r')
+      return True
+
+    if not os.path.exists(self._path):
+      self._problems.FeedNotFound(self._path)
+      return False
+
+    if self._path.endswith('.zip'):
+      try:
+        self._zip = zipfile.ZipFile(self._path, mode='r')
+      except IOError:  # self._path is a directory
+        pass
+      except zipfile.BadZipfile:
+        self._problems.UnknownFormat(self._path)
+        return False
+
+    if not self._zip and not os.path.isdir(self._path):
+      self._problems.UnknownFormat(self._path)
+      return False
+
+    return True
+
+  def _GetFileNames(self):
+    """Returns a list of file names in the feed."""
+    if self._zip:
+      return self._zip.namelist()
+    else:
+      return os.listdir(self._path)
+
+  def _CheckFileNames(self):
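+    """Report an UnknownFile problem for each file in the feed that is not a
+       known GTFS filename; hidden files are ignored."""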
+    filenames = self._GetFileNames()
+    known_filenames = self._gtfs_factory.GetKnownFilenames()
+    for feed_file in filenames:
+      if feed_file not in known_filenames:
+        if not feed_file.startswith('.'):
+          # Hidden files such as .svn entries are skipped; reporting them
+          # as unknown files would break the tests.
+          self._problems.UnknownFile(feed_file)
+
+  def _GetUtf8Contents(self, file_name):
+    """Check for errors in file_name and return a string for csv reader."""
+    contents = self._FileContents(file_name)
+    if not contents:  # Missing file
+      return
+
+    # Check for errors that will prevent csv.reader from working
+    if len(contents) >= 2 and contents[0:2] in (codecs.BOM_UTF16_BE,
+        codecs.BOM_UTF16_LE):
+      self._problems.FileFormat("appears to be encoded in utf-16", (file_name, ))
+      # Convert and continue, so we can find more errors
+      contents = codecs.getdecoder('utf-16')(contents)[0].encode('utf-8')
+
+    null_index = contents.find('\0')
+    if null_index != -1:
+      # It is easier to get some surrounding text than calculate the exact
+      # row_num
+      m = re.search(r'.{,20}\0.{,20}', contents, re.DOTALL)
+      self._problems.FileFormat(
+          "contains a null in text \"%s\" at byte %d" %
+          (codecs.getencoder('string_escape')(m.group())[0], null_index + 1),
+          (file_name, ))
+      return
+
+    # strip out any UTF-8 Byte Order Marker (otherwise it'll be
+    # treated as part of the first column name, causing a mis-parse)
+    contents = contents.lstrip(codecs.BOM_UTF8)
+    return contents
+
+  def _ReadCsvDict(self, file_name, all_cols, required):
+    """Reads lines from file_name, yielding a dict of unicode values."""
+    assert file_name.endswith(".txt")
+    table_name = file_name[0:-4]
+    contents = self._GetUtf8Contents(file_name)
+    if not contents:
+      return
+
+    eol_checker = util.EndOfLineChecker(StringIO.StringIO(contents),
+                                   file_name, self._problems)
+    # The csv module doesn't provide a way to skip trailing space, but when I
+    # checked, 15/675 feeds had trailing space in a header row and 120 had
+    # spaces after fields. Space after header fields can cause a serious
+    # parsing problem, so warn. Space after body fields can cause problems for
+    # time, integer and id fields; they will be validated at higher levels.
+    reader = csv.reader(eol_checker, skipinitialspace=True)
+
+    raw_header = reader.next()
+    header_occurrences = util.defaultdict(lambda: 0)
+    header = []
+    valid_columns = []  # Index into raw_header and raw_row
+    for i, h in enumerate(raw_header):
+      h_stripped = h.strip()
+      if not h_stripped:
+        self._problems.CsvSyntax(
+            description="The header row should not contain any blank values. "
+                        "The corresponding column will be skipped for the "
+                        "entire file.",
+            context=(file_name, 1, [''] * len(raw_header), raw_header),
+            type=problems.TYPE_ERROR)
+        continue
+      elif h != h_stripped:
+        self._problems.CsvSyntax(
+            description="The header row should not contain any "
+                        "space characters.",
+            context=(file_name, 1, [''] * len(raw_header), raw_header),
+            type=problems.TYPE_WARNING)
+      header.append(h_stripped)
+      valid_columns.append(i)
+      header_occurrences[h_stripped] += 1
+
+    for name, count in header_occurrences.items():
+      if count > 1:
+        self._problems.DuplicateColumn(
+            header=name,
+            file_name=file_name,
+            count=count)
+
+    self._schedule._table_columns[table_name] = header
+
+    # check for unrecognized columns, which are often misspellings
+    unknown_cols = set(header) - set(all_cols)
+    if len(unknown_cols) == len(header):
+      self._problems.CsvSyntax(
+            description="The header row did not contain any known column "
+                        "names. The file is most likely missing the header row "
+                        "or not in the expected CSV format.",
+            context=(file_name, 1, [''] * len(raw_header), raw_header),
+            type=problems.TYPE_ERROR)
+    else:
+      for col in unknown_cols:
+        # this is provided in order to create a nice colored list of
+        # columns in the validator output
+        context = (file_name, 1, [''] * len(header), header)
+        self._problems.UnrecognizedColumn(file_name, col, context)
+
+    missing_cols = set(required) - set(header)
+    for col in missing_cols:
+      # this is provided in order to create a nice colored list of
+      # columns in the validator output
+      context = (file_name, 1, [''] * len(header), header)
+      self._problems.MissingColumn(file_name, col, context)
+
+    line_num = 1  # First line read by reader.next() above
+    for raw_row in reader:
+      line_num += 1
+      if len(raw_row) == 0:  # skip extra empty lines in file
+        continue
+
+      if len(raw_row) > len(raw_header):
+        self._problems.OtherProblem('Found too many cells (commas) in line '
+                                    '%d of file "%s".  Every row in the file '
+                                    'should have the same number of cells as '
+                                    'the header (first line) does.' %
+                                    (line_num, file_name),
+                                    (file_name, line_num),
+                                    type=problems.TYPE_WARNING)
+
+      if len(raw_row) < len(raw_header):
+        self._problems.OtherProblem('Found missing cells (commas) in line '
+                                    '%d of file "%s".  Every row in the file '
+                                    'should have the same number of cells as '
+                                    'the header (first line) does.' %
+                                    (line_num, file_name),
+                                    (file_name, line_num),
+                                    type=problems.TYPE_WARNING)
+
+      # raw_row is a list of raw bytes which should be valid utf-8. Convert
+      # each column listed in valid_columns into Unicode.
+      valid_values = []
+      unicode_error_columns = []  # index of valid_values elements with an error
+      for i in valid_columns:
+        try:
+          valid_values.append(raw_row[i].decode('utf-8'))
+        except UnicodeDecodeError:
+          # Replace all invalid characters with REPLACEMENT CHARACTER (U+FFFD)
+          valid_values.append(codecs.getdecoder("utf8")
+                              (raw_row[i], "replace")[0])
+          unicode_error_columns.append(len(valid_values) - 1)
+        except IndexError:
+          break
+
+      # The error report may contain a dump of all values in valid_values so
+      # problems can not be reported until after converting all of raw_row to
+      # Unicode.
+      for i in unicode_error_columns:
+        self._problems.InvalidValue(header[i], valid_values[i],
+                                    'Unicode error',
+                                    (file_name, line_num,
+                                     valid_values, header))
+
+      d = dict(zip(header, valid_values))
+      yield (d, line_num, header, valid_values)
+
+  # TODO: Add testing for this specific function
+  def _ReadCSV(self, file_name, cols, required):
+    """Reads lines from file_name, yielding a list of unicode values
+    corresponding to the column names in cols."""
+    contents = self._GetUtf8Contents(file_name)
+    if not contents:
+      return
+
+    eol_checker = util.EndOfLineChecker(StringIO.StringIO(contents),
+                                   file_name, self._problems)
+    reader = csv.reader(eol_checker)  # Use excel dialect
+
+    header = reader.next()
+    header = map(lambda x: x.strip(), header)  # trim any whitespace
+    header_occurrences = util.defaultdict(lambda: 0)
+    for column_header in header:
+      header_occurrences[column_header] += 1
+
+    for name, count in header_occurrences.items():
+      if count > 1:
+        self._problems.DuplicateColumn(
+            header=name,
+            file_name=file_name,
+            count=count)
+
+    # check for unrecognized columns, which are often misspellings
+    unknown_cols = set(header).difference(set(cols))
+    for col in unknown_cols:
+      # this is provided in order to create a nice colored list of
+      # columns in the validator output
+      context = (file_name, 1, [''] * len(header), header)
+      self._problems.UnrecognizedColumn(file_name, col, context)
+
+    col_index = [-1] * len(cols)
+    for i in range(len(cols)):
+      if cols[i] in header:
+        col_index[i] = header.index(cols[i])
+      elif cols[i] in required:
+        self._problems.MissingColumn(file_name, cols[i])
+
+    row_num = 1
+    for row in reader:
+      row_num += 1
+      if len(row) == 0:  # skip extra empty lines in file
+        continue
+
+      if len(row) > len(header):
+        self._problems.OtherProblem('Found too many cells (commas) in line '
+                                    '%d of file "%s".  Every row in the file '
+                                    'should have the same number of cells as '
+                                    'the header (first line) does.' %
+                                    (row_num, file_name), (file_name, row_num),
+                                    type=problems.TYPE_WARNING)
+
+      if len(row) < len(header):
+        self._problems.OtherProblem('Found missing cells (commas) in line '
+                                    '%d of file "%s".  Every row in the file '
+                                    'should have the same number of cells as '
+                                    'the header (first line) does.' %
+                                    (row_num, file_name), (file_name, row_num),
+                                    type=problems.TYPE_WARNING)
+
+      result = [None] * len(cols)
+      unicode_error_columns = []  # A list of column numbers with an error
+      for i in range(len(cols)):
+        ci = col_index[i]
+        if ci >= 0:
+          if len(row) <= ci:  # handle short CSV rows
+            result[i] = u''
+          else:
+            try:
+              result[i] = row[ci].decode('utf-8').strip()
+            except UnicodeDecodeError:
+              # Replace all invalid characters with
+              # REPLACEMENT CHARACTER (U+FFFD)
+              result[i] = codecs.getdecoder("utf8")(row[ci],
+                                                    "replace")[0].strip()
+              unicode_error_columns.append(i)
+
+      for i in unicode_error_columns:
+        self._problems.InvalidValue(cols[i], result[i],
+                                    'Unicode error',
+                                    (file_name, row_num, result, cols))
+      yield (result, row_num, cols)
+
+  def _HasFile(self, file_name):
+    """Returns True if there's a file in the current feed with the
+       given file_name in the current feed."""
+    if self._zip:
+      return file_name in self._zip.namelist()
+    else:
+      file_path = os.path.join(self._path, file_name)
+      return os.path.exists(file_path) and os.path.isfile(file_path)
+
+  def _FileContents(self, file_name):
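+    """Return the contents of file_name as a byte string, reporting a problem
+       if the file is missing or empty."""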
+    results = None
+    if self._zip:
+      try:
+        results = self._zip.read(file_name)
+      except KeyError:  # file not found in archive
+        self._problems.MissingFile(file_name)
+        return None
+    else:
+      try:
+        data_file = open(os.path.join(self._path, file_name), 'rb')
+        results = data_file.read()
+      except IOError:  # file not found
+        self._problems.MissingFile(file_name)
+        return None
+
+    if not results:
+      self._problems.EmptyFile(file_name)
+    return results
+
+  def _LoadFeed(self):
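+    """Load each known GTFS file in loading order, building objects through
+       the factory and adding them to the schedule."""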
+    loading_order = self._gtfs_factory.GetLoadingOrder()
+    for filename in loading_order:
+      if not self._gtfs_factory.IsFileRequired(filename) and \
+         not self._HasFile(filename):
+        pass # File is not required, and feed does not have it.
+      else:
+        object_class = self._gtfs_factory.GetGtfsClassByFileName(filename)
+        for (d, row_num, header, row) in self._ReadCsvDict(
+                                       filename,
+                                       object_class._FIELD_NAMES,
+                                       object_class._REQUIRED_FIELD_NAMES):
+          self._problems.SetFileContext(filename, row_num, row, header)
+          instance = object_class(field_dict=d)
+          instance.SetGtfsFactory(self._gtfs_factory)
+          if not instance.ValidateBeforeAdd(self._problems):
+            continue
+          instance.AddToSchedule(self._schedule, self._problems)
+          instance.ValidateAfterAdd(self._problems)
+          self._problems.ClearContext()
+
+  def _LoadCalendar(self):
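+    """Load service periods from calendar.txt and calendar_dates.txt and add
+       them to the schedule."""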
+    file_name = 'calendar.txt'
+    file_name_dates = 'calendar_dates.txt'
+    if not self._HasFile(file_name) and not self._HasFile(file_name_dates):
+      self._problems.MissingFile(file_name)
+      return
+
+    # map period IDs to (period object, (file_name, row_num, row, cols))
+    periods = {}
+
+    # process calendar.txt
+    if self._HasFile(file_name):
+      has_useful_contents = False
+      for (row, row_num, cols) in \
+          self._ReadCSV(file_name,
+                        self._gtfs_factory.ServicePeriod._FIELD_NAMES,
+                        self._gtfs_factory.ServicePeriod._FIELD_NAMES_REQUIRED):
+        context = (file_name, row_num, row, cols)
+        self._problems.SetFileContext(*context)
+
+        period = self._gtfs_factory.ServicePeriod(field_list=row)
+
+        if period.service_id in periods:
+          self._problems.DuplicateID('service_id', period.service_id)
+        else:
+          periods[period.service_id] = (period, context)
+        self._problems.ClearContext()
+
+    # process calendar_dates.txt
+    if self._HasFile(file_name_dates):
+      # ['service_id', 'date', 'exception_type']
+      fields = self._gtfs_factory.ServicePeriod._FIELD_NAMES_CALENDAR_DATES
+      for (row, row_num, cols) in self._ReadCSV(file_name_dates,
+                                                fields, fields):
+        context = (file_name_dates, row_num, row, cols)
+        self._problems.SetFileContext(*context)
+
+        service_id = row[0]
+
+        period = None
+        if service_id in periods:
+          period = periods[service_id][0]
+        else:
+          period = self._gtfs_factory.ServicePeriod(service_id)
+          periods[period.service_id] = (period, context)
+
+        exception_type = row[2]
+        if exception_type == u'1':
+          period.SetDateHasService(row[1], True, self._problems)
+        elif exception_type == u'2':
+          period.SetDateHasService(row[1], False, self._problems)
+        else:
+          self._problems.InvalidValue('exception_type', exception_type)
+        self._problems.ClearContext()
+
+    # Now insert the periods into the schedule object, so that they're
+    # validated with both calendar and calendar_dates info present
+    for period, context in periods.values():
+      self._problems.SetFileContext(*context)
+      self._schedule.AddServicePeriodObject(period, self._problems)
+      self._problems.ClearContext()
+
+  def _LoadShapes(self):
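+    """Load shapes.txt, grouping points by shape_id and adding each Shape to
+       the schedule."""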
+    file_name = 'shapes.txt'
+    if not self._HasFile(file_name):
+      return
+    shapes = {}  # shape_id to shape object
+    for (d, row_num, header, row) in self._ReadCsvDict(
+        file_name, 
+        self._gtfs_factory.Shape._FIELD_NAMES,
+        self._gtfs_factory.Shape._REQUIRED_FIELD_NAMES):
+      file_context = (file_name, row_num, row, header)
+      self._problems.SetFileContext(*file_context)
+
+      shapepoint = self._gtfs_factory.ShapePoint(field_dict=d)
+      if not shapepoint.ParseAttributes(self._problems):
+        continue
+
+      if shapepoint.shape_id in shapes:
+        shape = shapes[shapepoint.shape_id]
+      else:
+        shape = self._gtfs_factory.Shape(shapepoint.shape_id)
+        shape.SetGtfsFactory(self._gtfs_factory)
+        shapes[shapepoint.shape_id] = shape
+
+      shape.AddShapePointObjectUnsorted(shapepoint, self._problems)
+      self._problems.ClearContext()
+      
+    for shape_id, shape in shapes.items():
+      self._schedule.AddShapeObject(shape, self._problems)
+      del shapes[shape_id]
+
+  def _LoadStopTimes(self):
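+    """Load stop_times.txt, attaching a StopTime to each referenced trip."""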
+    for (row, row_num, cols) in self._ReadCSV('stop_times.txt',
+        self._gtfs_factory.StopTime._FIELD_NAMES,
+        self._gtfs_factory.StopTime._REQUIRED_FIELD_NAMES):
+      file_context = ('stop_times.txt', row_num, row, cols)
+      self._problems.SetFileContext(*file_context)
+
+      (trip_id, arrival_time, departure_time, stop_id, stop_sequence,
+         stop_headsign, pickup_type, drop_off_type, shape_dist_traveled) = row
+
+      try:
+        sequence = int(stop_sequence)
+      except (TypeError, ValueError):
+        self._problems.InvalidValue('stop_sequence', stop_sequence,
+                                    'This should be a number.')
+        continue
+      if sequence < 0:
+        self._problems.InvalidValue('stop_sequence', sequence,
+                                    'Sequence numbers should be 0 or higher.')
+
+      if stop_id not in self._schedule.stops:
+        self._problems.InvalidValue('stop_id', stop_id,
+                                    'This value wasn\'t defined in stops.txt')
+        continue
+      stop = self._schedule.stops[stop_id]
+      if trip_id not in self._schedule.trips:
+        self._problems.InvalidValue('trip_id', trip_id,
+                                    'This value wasn\'t defined in trips.txt')
+        continue
+      trip = self._schedule.trips[trip_id]
+
+      # If self._problems.Report returns then StopTime.__init__ will return
+      # even if the StopTime object has an error. Thus this code may add a
+      # StopTime that didn't validate to the database.
+      # Trip.GetStopTimes then tries to make a StopTime from the invalid data
+      # and calls the problem reporter for errors. An ugly solution is to
+      # wrap problems and a better solution is to move all validation out of
+      # __init__. For now make sure Trip.GetStopTimes gets a problem reporter
+      # when called from Trip.Validate.
+      stop_time = self._gtfs_factory.StopTime(self._problems, stop,
+          arrival_time, departure_time, stop_headsign, pickup_type,
+          drop_off_type, shape_dist_traveled, stop_sequence=sequence)
+      trip._AddStopTimeObjectUnordered(stop_time, self._schedule)
+      self._problems.ClearContext()
+
+    # stop_times are validated in Trip.ValidateChildren, called by
+    # Schedule.Validate
+
+  def Load(self):
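+    """Load the feed and return the Schedule object it was read into."""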
+    self._problems.ClearContext()
+    if not self._DetermineFormat():
+      return self._schedule
+
+    self._CheckFileNames()
+    self._LoadCalendar()
+    self._LoadShapes()
+    self._LoadFeed()
+    
+    if self._load_stop_times:
+      self._LoadStopTimes()
+
+    if self._zip:
+      self._zip.close()
+      self._zip = None
+
+    if self._extra_validation:
+      self._schedule.Validate(self._problems, validate_children=False)
+
+    return self._schedule
+
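Below is a minimal sketch of driving this loader from outside the package, for
reference only. It assumes the package exposes Loader at the top level and uses
a hypothetical feed filename; it is not part of this change.

    import transitfeed

    # Load a zipped feed. The default problem reporter raises an exception for
    # each problem, so pass a more forgiving reporter to collect them all.
    loader = transitfeed.Loader(feed_path='google_transit.zip',  # hypothetical
                                memory_db=True)
    schedule = loader.Load()   # returns the populated Schedule object
    schedule.Validate()        # full validation pass over the loaded feed
    print '%d trips, %d stops' % (len(schedule.trips), len(schedule.stops))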