Upgrade origin-src to google transit feed 1.2.6
[bus.git] / origin-src / transitfeed-1.2.6 / transitfeed / trip.py
blob:a/origin-src/transitfeed-1.2.6/transitfeed/trip.py -> blob:b/origin-src/transitfeed-1.2.6/transitfeed/trip.py
  #!/usr/bin/python2.5
   
  # Copyright (C) 2007 Google Inc.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
   
  import warnings
   
  from gtfsobjectbase import GtfsObjectBase
  import problems as problems_module
  import util
   
  class Trip(GtfsObjectBase):
  _REQUIRED_FIELD_NAMES = ['route_id', 'service_id', 'trip_id']
  _FIELD_NAMES = _REQUIRED_FIELD_NAMES + [
  'trip_headsign', 'direction_id', 'block_id', 'shape_id'
  ]
  _TABLE_NAME= "trips"
   
  def __init__(self, headsign=None, service_period=None,
  route=None, trip_id=None, field_dict=None):
  self._schedule = None
  self._headways = [] # [(start_time, end_time, headway_secs)]
  if not field_dict:
  field_dict = {}
  if headsign is not None:
  field_dict['trip_headsign'] = headsign
  if route:
  field_dict['route_id'] = route.route_id
  if trip_id is not None:
  field_dict['trip_id'] = trip_id
  if service_period is not None:
  field_dict['service_id'] = service_period.service_id
  # Earlier versions of transitfeed.py assigned self.service_period here
  # and allowed the caller to set self.service_id. Schedule.Validate
  # checked the service_id attribute if it was assigned and changed it to a
  # service_period attribute. Now only the service_id attribute is used and
  # it is validated by Trip.Validate.
  if service_period is not None:
  # For backwards compatibility
  self.service_id = service_period.service_id
  self.__dict__.update(field_dict)
   
  def GetFieldValuesTuple(self):
  return [getattr(self, fn) or '' for fn in self._FIELD_NAMES]
   
  def AddStopTime(self, stop, problems=None, schedule=None, **kwargs):
  """Add a stop to this trip. Stops must be added in the order visited.
   
  Args:
  stop: A Stop object
  kwargs: remaining keyword args passed to StopTime.__init__
   
  Returns:
  None
  """
  if problems is None:
  # TODO: delete this branch when StopTime.__init__ doesn't need a
  # ProblemReporter
  problems = problems_module.default_problem_reporter
  stoptime = self.GetGtfsFactory().StopTime(
  problems=problems, stop=stop, **kwargs)
  self.AddStopTimeObject(stoptime, schedule)
   
  def _AddStopTimeObjectUnordered(self, stoptime, schedule):
  """Add StopTime object to this trip.
   
  The trip isn't checked for duplicate sequence numbers so it must be
  validated later."""
  stop_time_class = self.GetGtfsFactory().StopTime
  cursor = schedule._connection.cursor()
  insert_query = "INSERT INTO stop_times (%s) VALUES (%s);" % (
  ','.join(stop_time_class._SQL_FIELD_NAMES),
  ','.join(['?'] * len(stop_time_class._SQL_FIELD_NAMES)))
  cursor = schedule._connection.cursor()
  cursor.execute(
  insert_query, stoptime.GetSqlValuesTuple(self.trip_id))
   
  def ReplaceStopTimeObject(self, stoptime, schedule=None):
  """Replace a StopTime object from this trip with the given one.
   
  Keys the StopTime object to be replaced by trip_id, stop_sequence
  and stop_id as 'stoptime', with the object 'stoptime'.
  """
   
  if schedule is None:
  schedule = self._schedule
   
  new_secs = stoptime.GetTimeSecs()
  cursor = schedule._connection.cursor()
  cursor.execute("DELETE FROM stop_times WHERE trip_id=? and "
  "stop_sequence=? and stop_id=?",
  (self.trip_id, stoptime.stop_sequence, stoptime.stop_id))
  if cursor.rowcount == 0:
  raise problems_module.Error, 'Attempted replacement of StopTime object which does not exist'
  self._AddStopTimeObjectUnordered(stoptime, schedule)
   
  def AddStopTimeObject(self, stoptime, schedule=None, problems=None):
  """Add a StopTime object to the end of this trip.
   
  Args:
  stoptime: A StopTime object. Should not be reused in multiple trips.
  schedule: Schedule object containing this trip which must be
  passed to Trip.__init__ or here
  problems: ProblemReporter object for validating the StopTime in its new
  home
   
  Returns:
  None
  """
  if schedule is None:
  schedule = self._schedule
  if schedule is None:
  warnings.warn("No longer supported. _schedule attribute is used to get "
  "stop_times table", DeprecationWarning)
  if problems is None:
  problems = schedule.problem_reporter
   
  new_secs = stoptime.GetTimeSecs()
  cursor = schedule._connection.cursor()
  cursor.execute("SELECT max(stop_sequence), max(arrival_secs), "
  "max(departure_secs) FROM stop_times WHERE trip_id=?",
  (self.trip_id,))
  row = cursor.fetchone()
  if row[0] is None:
  # This is the first stop_time of the trip
  stoptime.stop_sequence = 1
  if new_secs == None:
  problems.OtherProblem(
  'No time for first StopTime of trip_id "%s"' % (self.trip_id,))
  else:
  stoptime.stop_sequence = row[0] + 1
  prev_secs = max(row[1], row[2])
  if new_secs != None and new_secs < prev_secs:
  problems.OtherProblem(
  'out of order stop time for stop_id=%s trip_id=%s %s < %s' %
  (util.EncodeUnicode(stoptime.stop_id),
  util.EncodeUnicode(self.trip_id),
  util.FormatSecondsSinceMidnight(new_secs),
  util.FormatSecondsSinceMidnight(prev_secs)))
  self._AddStopTimeObjectUnordered(stoptime, schedule)
   
  def GetTimeStops(self):
  """Return a list of (arrival_secs, departure_secs, stop) tuples.
   
  Caution: arrival_secs and departure_secs may be 0, a false value meaning a
  stop at midnight or None, a false value meaning the stop is untimed."""
  return [(st.arrival_secs, st.departure_secs, st.stop) for st in
  self.GetStopTimes()]
   
  def GetCountStopTimes(self):
  """Return the number of stops made by this trip."""
  cursor = self._schedule._connection.cursor()
  cursor.execute(
  'SELECT count(*) FROM stop_times WHERE trip_id=?', (self.trip_id,))
  return cursor.fetchone()[0]
   
  def GetTimeInterpolatedStops(self):
  """Return a list of (secs, stoptime, is_timepoint) tuples.
   
  secs will always be an int. If the StopTime object does not have explict
  times this method guesses using distance. stoptime is a StopTime object and
  is_timepoint is a bool.
   
  Raises:
  ValueError if this trip does not have the times needed to interpolate
  """
  rv = []
   
  stoptimes = self.GetStopTimes()
  # If there are no stoptimes [] is the correct return value but if the start
  # or end are missing times there is no correct return value.
  if not stoptimes:
  return []
  if (stoptimes[0].GetTimeSecs() is None or
  stoptimes[-1].GetTimeSecs() is None):
  raise ValueError("%s must have time at first and last stop" % (self))
   
  cur_timepoint = None
  next_timepoint = None
  distance_between_timepoints = 0
  distance_traveled_between_timepoints = 0
   
  for i, st in enumerate(stoptimes):
  if st.GetTimeSecs() != None:
  cur_timepoint = st
  distance_between_timepoints = 0
  distance_traveled_between_timepoints = 0
  if i + 1 < len(stoptimes):
  k = i + 1
  distance_between_timepoints += util.ApproximateDistanceBetweenStops(stoptimes[k-1].stop, stoptimes[k].stop)
  while stoptimes[k].GetTimeSecs() == None:
  k += 1
  distance_between_timepoints += util.ApproximateDistanceBetweenStops(stoptimes[k-1].stop, stoptimes[k].stop)
  next_timepoint = stoptimes[k]
  rv.append( (st.GetTimeSecs(), st, True) )
  else:
  distance_traveled_between_timepoints += util.ApproximateDistanceBetweenStops(stoptimes[i-1].stop, st.stop)
  distance_percent = distance_traveled_between_timepoints / distance_between_timepoints
  total_time = next_timepoint.GetTimeSecs() - cur_timepoint.GetTimeSecs()
  time_estimate = distance_percent * total_time + cur_timepoint.GetTimeSecs()
  rv.append( (int(round(time_estimate)), st, False) )
   
  return rv
   
  def ClearStopTimes(self):
  """Remove all stop times from this trip.
   
  StopTime objects previously returned by GetStopTimes are unchanged but are
  no longer associated with this trip.
  """
  cursor = self._schedule._connection.cursor()
  cursor.execute('DELETE FROM stop_times WHERE trip_id=?', (self.trip_id,))
   
  def GetStopTimes(self, problems=None):
  """Return a sorted list of StopTime objects for this trip."""
  # In theory problems=None should be safe because data from database has been
  # validated. See comment in _LoadStopTimes for why this isn't always true.
  cursor = self._schedule._connection.cursor()
  cursor.execute(
  'SELECT arrival_secs,departure_secs,stop_headsign,pickup_type,'
  'drop_off_type,shape_dist_traveled,stop_id,stop_sequence FROM '
  'stop_times WHERE '
  'trip_id=? ORDER BY stop_sequence', (self.trip_id,))
  stop_times = []
  stoptime_class = self.GetGtfsFactory().StopTime
  for row in cursor.fetchall():
  stop = self._schedule.GetStop(row[6])
  stop_times.append(stoptime_class(problems=problems,
  stop=stop,
  arrival_secs=row[0],
  departure_secs=row[1],
  stop_headsign=row[2],
  pickup_type=row[3],
  drop_off_type=row[4],
  shape_dist_traveled=row[5],
  stop_sequence=row[7]))
  return stop_times
   
  def GetHeadwayStopTimes(self, problems=None):
  """Deprecated. Please use GetFrequencyStopTimes instead."""
  warnings.warn("No longer supported. The HeadwayPeriod class was renamed to "
  "Frequency, and all related functions were renamed "
  "accordingly.", DeprecationWarning)
  return self.GetFrequencyStopTimes(problems)
   
  def GetFrequencyStopTimes(self, problems=None):
  """Return a list of StopTime objects for each headway-based run.
   
  Returns:
  a list of list of StopTime objects. Each list of StopTime objects
  represents one run. If this trip doesn't have headways returns an empty
  list.
  """
  stoptimes_list = [] # list of stoptime lists to be returned
  stoptime_pattern = self.GetStopTimes()
  first_secs = stoptime_pattern[0].arrival_secs # first time of the trip
  stoptime_class = self.GetGtfsFactory().StopTime
  # for each start time of a headway run
  for run_secs in self.GetFrequencyStartTimes():
  # stop time list for a headway run
  stoptimes = []
  # go through the pattern and generate stoptimes
  for st in stoptime_pattern:
  arrival_secs, departure_secs = None, None # default value if the stoptime is not timepoint
  if st.arrival_secs != None:
  arrival_secs = st.arrival_secs - first_secs + run_secs
  if st.departure_secs != None:
  departure_secs = st.departure_secs - first_secs + run_secs
  # append stoptime
  stoptimes.append(stoptime_class(problems=problems, stop=st.stop,
  arrival_secs=arrival_secs,
  departure_secs=departure_secs,
  stop_headsign=st.stop_headsign,
  pickup_type=st.pickup_type,
  drop_off_type=st.drop_off_type,
  shape_dist_traveled= \
  st.shape_dist_traveled,
  stop_sequence=st.stop_sequence))
  # add stoptimes to the stoptimes_list
  stoptimes_list.append ( stoptimes )
  return stoptimes_list
   
  def GetStartTime(self, problems=problems_module.default_problem_reporter):
  """Return the first time of the trip. TODO: For trips defined by frequency
  return the first time of the first trip."""
  cursor = self._schedule._connection.cursor()
  cursor.execute(
  'SELECT arrival_secs,departure_secs FROM stop_times WHERE '
  'trip_id=? ORDER BY stop_sequence LIMIT 1', (self.trip_id,))
  (arrival_secs, departure_secs) = cursor.fetchone()
  if arrival_secs != None:
  return arrival_secs
  elif departure_secs != None:
  return departure_secs
  else:
  problems.InvalidValue('departure_time', '',
  'The first stop_time in trip %s is missing '
  'times.' % self.trip_id)
   
  def GetHeadwayStartTimes(self):
  """Deprecated. Please use GetFrequencyStartTimes instead."""
  warnings.warn("No longer supported. The HeadwayPeriod class was renamed to "
  "Frequency, and all related functions were renamed "
  "accordingly.", DeprecationWarning)
  return self.GetFrequencyStartTimes()
   
  def GetFrequencyStartTimes(self):
  """Return a list of start time for each headway-based run.
   
  Returns:
  a sorted list of seconds since midnight, the start time of each run. If
  this trip doesn't have headways returns an empty list."""
  start_times = []
  # for each headway period of the trip
  for start_secs, end_secs, headway_secs in self.GetFrequencyTuples():
  # reset run secs to the start of the timeframe
  run_secs = start_secs
  while run_secs < end_secs:
  start_times.append(run_secs)
  # increment current run secs by headway secs
  run_secs += headway_secs
  return start_times
   
  def GetEndTime(self, problems=problems_module.default_problem_reporter):
  """Return the last time of the trip. TODO: For trips defined by frequency
  return the last time of the last trip."""
  cursor = self._schedule._connection.cursor()
  cursor.execute(
  'SELECT arrival_secs,departure_secs FROM stop_times WHERE '
  'trip_id=? ORDER BY stop_sequence DESC LIMIT 1', (self.trip_id,))
  (arrival_secs, departure_secs) = cursor.fetchone()
  if departure_secs != None:
  return departure_secs
  elif arrival_secs != None:
  return arrival_secs
  else:
  problems.InvalidValue('arrival_time', '',
  'The last stop_time in trip %s is missing '
  'times.' % self.trip_id)