Upgrade origin-src to google transit feed 1.2.6
origin-src/transitfeed-1.2.6/transitfeed/loader.py
  #!/usr/bin/python2.5
   
  # Copyright (C) 2007 Google Inc.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
   
  import codecs
  import cStringIO as StringIO
  import csv
  import os
  import re
  import zipfile
   
  import gtfsfactory as gtfsfactory_module
  import problems
  import util
   
class Loader:
  def __init__(self,
               feed_path=None,
               schedule=None,
               problems=problems.default_problem_reporter,
               extra_validation=False,
               load_stop_times=True,
               memory_db=True,
               zip=None,
               check_duplicate_trips=False,
               gtfs_factory=None):
    """Initialize a new Loader object.

    Args:
      feed_path: string path to a zip file or directory
      schedule: a Schedule object or None to have one created
      problems: a ProblemReporter object, the default reporter raises an
        exception for each problem
      extra_validation: True if you would like extra validation
      load_stop_times: load the stop_times table, used to speed load time when
        times are not needed. The default is True.
      memory_db: if creating a new Schedule object use an in-memory sqlite
        database instead of creating one in a temporary file
      zip: a zipfile.ZipFile object, optionally used instead of path
      check_duplicate_trips: if True, a newly created Schedule object checks
        for duplicate trips
      gtfs_factory: a GtfsFactory object, or None to use the default factory
    """
    if gtfs_factory is None:
      gtfs_factory = gtfsfactory_module.GetGtfsFactory()

    if not schedule:
      schedule = gtfs_factory.Schedule(problem_reporter=problems,
          memory_db=memory_db, check_duplicate_trips=check_duplicate_trips)

    self._extra_validation = extra_validation
    self._schedule = schedule
    self._problems = problems
    self._path = feed_path
    self._zip = zip
    self._load_stop_times = load_stop_times
    self._gtfs_factory = gtfs_factory
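
  # Typical usage, as an illustrative sketch only: it assumes the Load()
  # method this class provides further down in this file (not shown in this
  # excerpt) and a hypothetical local feed path.
  #
  #   loader = Loader(feed_path='google_transit.zip', extra_validation=True)
  #   schedule = loader.Load()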
   
  def _DetermineFormat(self):
    """Determines whether the feed is in a form that we understand, and
       if so, returns True."""
    if self._zip:
      # If zip was passed to __init__ then path isn't used
      assert not self._path
      return True

    if not isinstance(self._path, basestring) and hasattr(self._path, 'read'):
      # A file-like object, used for testing with a StringIO file
      self._zip = zipfile.ZipFile(self._path, mode='r')
      return True

    if not os.path.exists(self._path):
      self._problems.FeedNotFound(self._path)
      return False

    if self._path.endswith('.zip'):
      try:
        self._zip = zipfile.ZipFile(self._path, mode='r')
      except IOError:  # self._path is a directory
        pass
      except zipfile.BadZipfile:
        self._problems.UnknownFormat(self._path)
        return False

    if not self._zip and not os.path.isdir(self._path):
      self._problems.UnknownFormat(self._path)
      return False

    return True
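
  # Illustrative sketch of the input forms accepted above, with hypothetical
  # paths (StringIO and zipfile are already imported by this module):
  #
  #   Loader(feed_path='feed.zip')                   # zip archive on disk
  #   Loader(feed_path='feed_dir')                   # directory of .txt files
  #   Loader(feed_path=StringIO.StringIO(zip_data))  # file-like object (tests)
  #   Loader(zip=zipfile.ZipFile('feed.zip'))        # pre-opened ZipFile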
   
  def _GetFileNames(self):
    """Returns a list of file names in the feed."""
    if self._zip:
      return self._zip.namelist()
    else:
      return os.listdir(self._path)
   
  def _CheckFileNames(self):
    filenames = self._GetFileNames()
    known_filenames = self._gtfs_factory.GetKnownFilenames()
    for feed_file in filenames:
      if feed_file not in known_filenames:
        if not feed_file.startswith('.'):
          # Skip .svn files and other hidden files; reporting
          # them would break the tests.
          self._problems.UnknownFile(feed_file)
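
  # Illustrative: GetKnownFilenames() covers the standard GTFS file names
  # such as 'agency.txt', 'stops.txt', 'routes.txt', 'trips.txt',
  # 'stop_times.txt' and 'calendar.txt', so a stray file like 'stops.csv'
  # triggers an UnknownFile problem while hidden '.svn' entries are ignored.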
   
  def _GetUtf8Contents(self, file_name):
    """Check for errors in file_name and return a string for csv reader."""
    contents = self._FileContents(file_name)
    if not contents:  # Missing file
      return

    # Check for errors that will prevent csv.reader from working
    if len(contents) >= 2 and contents[0:2] in (codecs.BOM_UTF16_BE,
        codecs.BOM_UTF16_LE):
      self._problems.FileFormat("appears to be encoded in utf-16",
                                (file_name, ))
      # Convert and continue, so we can find more errors
      contents = codecs.getdecoder('utf-16')(contents)[0].encode('utf-8')

    null_index = contents.find('\0')
    if null_index != -1:
      # It is easier to get some surrounding text than calculate the exact
      # row_num
      m = re.search(r'.{,20}\0.{,20}', contents, re.DOTALL)
      self._problems.FileFormat(
          # getencoder returns an (encoded_string, length) tuple; only the
          # escaped string belongs in the message
          "contains a null in text \"%s\" at byte %d" %
          (codecs.getencoder('string_escape')(m.group())[0], null_index + 1),
          (file_name, ))
      return

    # Strip out any UTF-8 Byte Order Marker (otherwise it'll be
    # treated as part of the first column name, causing a mis-parse).
    # Use a prefix test rather than lstrip, which treats its argument as a
    # set of characters and could eat the start of the header.
    if contents.startswith(codecs.BOM_UTF8):
      contents = contents[len(codecs.BOM_UTF8):]
    return contents
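
  # Illustrative: the utf-16 detection above relies on the BOM constants
  # from the codecs module, e.g.:
  #
  #   codecs.BOM_UTF16_LE == '\xff\xfe'
  #   contents = codecs.BOM_UTF16_LE + 'stop_id'.encode('utf-16-le')
  #   contents[0:2] in (codecs.BOM_UTF16_BE, codecs.BOM_UTF16_LE)  # True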
   
  def _ReadCsvDict(self, file_name, all_cols, required):
    """Reads lines from file_name, yielding a dict of unicode values."""
    assert file_name.endswith(".txt")
    table_name = file_name[0:-4]
    contents = self._GetUtf8Contents(file_name)
    if not contents:
      return

    eol_checker = util.EndOfLineChecker(StringIO.StringIO(contents),
                                        file_name, self._problems)
    # The csv module doesn't provide a way to skip trailing space, but when I
    # checked 15/675 feeds had trailing space in a header row and 120 had
    # spaces after fields. Space after header fields can cause a serious
    # parsing problem, so warn. Space after body fields can cause problems
    # for time, integer and id fields; they will be validated at higher
    # levels.
    reader = csv.reader(eol_checker, skipinitialspace=True)
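    # Illustrative: skipinitialspace only skips space that directly follows
    # a delimiter, e.g.:
    #
    #   list(csv.reader(['a, b,c '], skipinitialspace=True))
    #   # -> [['a', 'b', 'c ']]  (space after comma dropped, trailing kept)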
   
    raw_header = reader.next()
    header_occurrences = util.defaultdict(lambda: 0)
    header = []
    valid_columns = []  # Index into raw_header and raw_row
    for i, h in enumerate(raw_header):
      h_stripped = h.strip()
      if not h_stripped:
        self._problems.CsvSyntax(
            description="The header row should not contain any blank values. "
                        "The corresponding column will be skipped for the "
                        "entire file.",
            context=(file_name, 1, [''] * len(raw_header), raw_header),
            type=problems.TYPE_ERROR)
        continue
      elif h != h_stripped:
        self._problems.CsvSyntax(
            description="The header row should not contain any "
                        "space characters.",
            context=(file_name, 1, [''] * len(raw_header), raw_header),
            type=problems.TYPE_WARNING)
      header.append(h_stripped)
      valid_columns.append(i)
      header_occurrences[h_stripped] += 1

    for name, count in header_occurrences.items():
      if count > 1:
        self._problems.DuplicateColumn(
            header=name,
            file_name=file_name,
            count=count)

    self._schedule._table_columns[table_name] = header

    # check for unrecognized columns, which are often misspellings
    unknown_cols = set(header) - set(all_cols)
    if len(unknown_cols) == len(header):
      self._problems.CsvSyntax(
          description="The header row did not contain any known column "
                      "names. The file is most likely missing the header row "
                      "or not in the expected CSV format.",
          context=(file_name, 1, [''] * len(raw_header), raw_header),
          type=problems.TYPE_ERROR)
    else:
      for col in unknown_cols:
        # this is provided in order to create a nice colored list of
        # columns in the validator output
        context = (file_name, 1, [''] * len(header), header)
        self._problems.UnrecognizedColumn(file_name, col, context)

    missing_cols = set(required) - set(header)
    for col in missing_cols:
      # this is provided in order to create a nice colored list of
      # columns in the validator output
      context = (file_name, 1, [''] * len(header), header)
      self._problems.MissingColumn(file_name, col, context)

    line_num = 1  # First line read by reader.next() above
    for raw_row in reader:
      line_num += 1
      if len(raw_row) == 0:  # skip extra empty lines in file
        continue

      if len(raw_row) > len(raw_header):
        self._problems.OtherProblem('Found too many cells (commas) in line '
                                    '%d of file "%s". Every row in the file '
                                    'should have the same number of cells as '
                                    'the header (first line) does.' %
                                    (line_num, file_name),
                                    (file_name, line_num),
                                    type=problems.TYPE_WARNING)

      if len(raw_row) < len(raw_header):
        self._problems.OtherProblem('Found missing cells (commas) in line '
                                    '%d of file "%s". Every row in the file '
                                    'should have the same number of cells as '
                                    'the header (first line) does.' %
                                    (line_num, file_name),
                                    (file_name, line_num),
                                    type=problems.TYPE_WARNING)

      # raw_row is a list of raw bytes which should be valid utf-8. Convert
      # each element of raw_row listed in valid_columns to Unicode.
      valid_values = []
      unicode_error_columns = []  # index of valid_values elements with an error
      for i in valid_columns:
        try:
          valid_values.append(raw_row[i].decode('utf-8'))
        except UnicodeDecodeError:
          # Replace all invalid characters with REPLACEMENT CHARACTER (U+FFFD)
          valid_values.append(codecs.getdecoder("utf8")
                              (raw_row[i], errors="replace")[0])
          unicode_error_columns.append(len(valid_values) - 1)
        except IndexError:
          break

      # The error report may contain a dump of all values in valid_values so
      # problems cannot be reported until after converting all of raw_row to
      # Unicode.
      for i in unicode_error_columns:
        self._problems.InvalidValue(header[i], valid_values[i],
                                    'Unicode error',
                                    (file_name, line_num,
                                     valid_values, header))

      d = dict(zip(header, valid_values))
      yield (d, line_num, header, valid_values)
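
  # Illustrative: consuming this internal generator for a hypothetical
  # stops.txt, with made-up column lists:
  #
  #   for d, line_num, header, row in loader._ReadCsvDict(
  #       'stops.txt',
  #       all_cols=['stop_id', 'stop_name', 'stop_lat', 'stop_lon'],
  #       required=['stop_id', 'stop_name']):
  #     print line_num, d.get('stop_id')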
   
  # TODO: Add testing for this specific function
  def _ReadCSV(self, file_name, cols, required):
    """Reads lines from file_name, yielding a list of unicode values
    corresponding to the column names in cols."""
    contents = self._GetUtf8Contents(file_name)
    if not contents:
      return

    eol_checker = util.EndOfLineChecker(StringIO.StringIO(contents),
                                        file_name, self._problems)
    reader = csv.reader(eol_checker)  # Use excel dialect

    header = reader.next()
    header = map(lambda x: x.strip(), header)  # trim any whitespace
    header_occurrences = util.defaultdict(lambda: 0)
    for column_header in header:
      header_occurrences[column_header] += 1

    for name, count in header_occurrences.items():
      if count > 1:
        self._problems.DuplicateColumn(
            header=name,
            file_name=file_name,
            count=count)

    # check for unrecognized columns, which are often misspellings
    unknown_cols = set(header).difference(set(cols))
    for col in unknown_cols:
      # this is provided in order to create a nice colored list of
      # columns in the validator output
      context = (file_name, 1, [''] * len(header), header)
      self._problems.UnrecognizedColumn(file_name, col, context)

    col_index = [-1] * len(cols)
    for i in range(len(cols)):
      if cols[i] in header:
        col_index[i] = header.index(cols[i])
      elif cols[i] in required:
        self._problems.MissingColumn(file_name, cols[i])

    row_num = 1
    for row in reader:
      row_num += 1
      if len(row) == 0:  # skip extra empty lines in file
        continue

      if len(row) > len(header):
        self._problems.OtherProblem('Found too many cells (commas) in line '
                                    '%d of file "%s". Every row in the file '
                                    'should have the same number of cells as '
                                    'the header (first line) does.' %
                                    (row_num, file_name), (file_name, row_num),
                                    type=problems.TYPE_WARNING)

      if len(row) < len(header):
        self._problems.OtherProblem('Found missing cells (commas) in line '
                                    '%d of file "%s". Every row in the file '
                                    'should have the same number of cells as '
                                    'the header (first line) does.' %
                                    (row_num, file_name), (file_name, row_num),
                                    type=problems.TYPE_WARNING)

      result = [None] * len(cols)
      unicode_error_columns = []  # A list of column numbers with an error
      for i in range(len(cols)):
        ci = col_index[i]
        if ci >= 0:
          if len(row) <= ci:  # handle short CSV rows
            result[i] = u''
          else:
            try:
              result[i] = row[ci].decode('utf-8').strip()
            except UnicodeDecodeError:
              # Replace all invalid characters with
              # REPLACEMENT CHARACTER (U+FFFD)
              result[i] = codecs.getdecoder("utf8")(row[ci],
                  errors="replace")[0].strip()
              unicode_error_columns.append(i)

      for i in unicode_error_columns:
        self._problems.InvalidValue(cols[i], result[i],
                                    'Unicode error',
                                    (file_name, row_num, result, cols))
      yield (result, row_num, cols)
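
  # Illustrative worked example of the column mapping above:
  #
  #   cols   = ['a', 'b', 'c'],  required = ['b']
  #   header = ['c', 'a']        # first line of the file
  #   col_index == [1, -1, 0]    # 'b' absent -> MissingColumn reported
  #   each data row yields ([row[1], None, row[0]], row_num, cols)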
   
  def _HasFile(self, file_name):
    """Returns True if there's a fi