More between points work + mobile UI work:
[bus.git] / busui / schedule_viewer.py
blob:a/busui/schedule_viewer.py -> blob:b/busui/schedule_viewer.py
--- a/busui/schedule_viewer.py
+++ b/busui/schedule_viewer.py
@@ -1,1 +1,536 @@
-
+#!/usr/bin/python2.5
+
+# Copyright (C) 2007 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+An example application that uses the transitfeed module.
+
+You must provide a Google Maps API key.
+"""
+
+
+import BaseHTTPServer, sys, urlparse
+import bisect
+from gtfsscheduleviewer.marey_graph import MareyGraph
+import gtfsscheduleviewer
+import mimetypes
+import os.path
+import re
+import signal
+import simplejson
+import socket
+import time
+import transitfeed
+from transitfeed import util
+import urllib
+
+
+# By default Windows kills Python with Ctrl+Break. Instead make Ctrl+Break
+# raise a KeyboardInterrupt.
+if hasattr(signal, 'SIGBREAK'):
+  signal.signal(signal.SIGBREAK, signal.default_int_handler)
+
+
+mimetypes.add_type('text/plain', '.vbs')
+
+
+class ResultEncoder(simplejson.JSONEncoder):
+  def default(self, obj):
+    try:
+      iterable = iter(obj)
+    except TypeError:
+      pass
+    else:
+      return list(iterable)
+    return simplejson.JSONEncoder.default(self, obj)
+
+# Code taken from
+# http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/425210/index_txt
+# An alternate approach is shown at
+# http://mail.python.org/pipermail/python-list/2003-July/212751.html
+# but it requires multiple threads. A sqlite object can only be used from one
+# thread.
+class StoppableHTTPServer(BaseHTTPServer.HTTPServer):
+  def server_bind(self):
+    BaseHTTPServer.HTTPServer.server_bind(self)
+    self.socket.settimeout(1)
+    self._run = True
+
+  def get_request(self):
+    while self._run:
+      try:
+        sock, addr = self.socket.accept()
+        sock.settimeout(None)
+        return (sock, addr)
+      except socket.timeout:
+        pass
+
+  def stop(self):
+    self._run = False
+
+  def serve(self):
+    while self._run:
+      self.handle_request()
+
+
+def StopToTuple(stop):
+  """Return tuple as expected by javascript function addStopMarkerFromList"""
+  return (stop.stop_id, stop.stop_name, float(stop.stop_lat),
+          float(stop.stop_lon), stop.location_type)
+
+
+class ScheduleRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
+  def do_GET(self):
+    scheme, host, path, x, params, fragment = urlparse.urlparse(self.path)
+    parsed_params = {}
+    for k in params.split('&'):
+      k = urllib.unquote(k)
+      if '=' in k:
+        k, v = k.split('=', 1)
+        parsed_params[k] = unicode(v, 'utf8')
+      else:
+        parsed_params[k] = ''
+
+    if path == '/':
+      return self.handle_GET_home()
+
+    m = re.match(r'/json/([a-z]{1,64})', path)
+    if m:
+      handler_name = 'handle_json_GET_%s' % m.group(1)
+      handler = getattr(self, handler_name, None)
+      if callable(handler):
+        return self.handle_json_wrapper_GET(handler, parsed_params)
+
+    # Restrict allowable file names to prevent relative path attacks etc
+    m = re.match(r'/file/([a-z0-9_-]{1,64}\.?[a-z0-9_-]{1,64})$', path)
+    if m and m.group(1):
+      try:
+        f, mime_type = self.OpenFile(m.group(1))
+        return self.handle_static_file_GET(f, mime_type)
+      except IOError, e:
+        print "Error: unable to open %s" % m.group(1)
+        # Ignore and treat as 404
+
+    m = re.match(r'/([a-z]{1,64})', path)
+    if m:
+      handler_name = 'handle_GET_%s' % m.group(1)
+      handler = getattr(self, handler_name, None)
+      if callable(handler):
+        return handler(parsed_params)
+
+    return self.handle_GET_default(parsed_params, path)
+
+  def OpenFile(self, filename):
+    """Try to open filename in the static files directory of this server.
+    Return a tuple (file object, string mime_type) or raise an exception."""
+    (mime_type, encoding) = mimetypes.guess_type(filename)
+    assert mime_type
+    # A crude guess of when we should use binary mode. Without it non-unix
+    # platforms may corrupt binary files.
+    if mime_type.startswith('text/'):
+      mode = 'r'
+    else:
+      mode = 'rb'
+    return open(os.path.join(self.server.file_dir, filename), mode), mime_type
+
+  def handle_GET_default(self, parsed_params, path):
+    self.send_error(404)
+
+  def handle_static_file_GET(self, fh, mime_type):
+    content = fh.read()
+    self.send_response(200)
+    self.send_header('Content-Type', mime_type)
+    self.send_header('Content-Length', str(len(content)))
+    self.end_headers()
+    self.wfile.write(content)
+
+  def AllowEditMode(self):
+    return False
+
+  def handle_GET_home(self):
+    schedule = self.server.schedule
+    (min_lat, min_lon, max_lat, max_lon) = schedule.GetStopBoundingBox()
+    forbid_editing = ('true', 'false')[self.AllowEditMode()]
+
+    agency = ', '.join(a.agency_name for a in schedule.GetAgencyList()).encode('utf-8')
+
+    key = self.server.key
+    host = self.server.host
+
+    # A very simple template system. For a fixed set of values replace [xxx]
+    # with the value of local variable xxx
+    f, _ = self.OpenFile('index.html')
+    content = f.read()
+    for v in ('agency', 'min_lat', 'min_lon', 'max_lat', 'max_lon', 'key',
+              'host', 'forbid_editing'):
+      content = content.replace('[%s]' % v, str(locals()[v]))
+
+    self.send_response(200)
+    self.send_header('Content-Type', 'text/html')
+    self.send_header('Content-Length', str(len(content)))
+    self.end_headers()
+    self.wfile.write(content)
+
+  def handle_json_GET_routepatterns(self, params):
+    """Given a route_id generate a list of patterns of the route. For each
+    pattern include some basic information and a few sample trips."""
+    schedule = self.server.schedule
+    route = schedule.GetRoute(params.get('route', None))
+    if not route:
+      self.send_error(404)
+      return
+    time = int(params.get('time', 0))
+    sample_size = 10  # For each pattern return the start time for this many trips
+
+    pattern_id_trip_dict = route.GetPatternIdTripDict()
+    patterns = []
+
+    for pattern_id, trips in pattern_id_trip_dict.items():
+      time_stops = trips[0].GetTimeStops()
+      if not time_stops:
+        continue
+      has_non_zero_trip_type = False;
+      for trip in trips:
+        if trip['trip_type'] and trip['trip_type'] != '0':
+          has_non_zero_trip_type = True
+      name = u'%s to %s, %d stops' % (time_stops[0][2].stop_name, time_stops[-1][2].stop_name, len(time_stops))
+      transitfeed.SortListOfTripByTime(trips)
+
+      num_trips = len(trips)
+      if num_trips <= sample_size:
+        start_sample_index = 0
+        num_after_sample = 0
+      else:
+        # Will return sample_size trips that start after the 'time' param.
+
+        # Linear search because I couldn't find a built-in way to do a binary
+        # search with a custom key.
+        start_sample_index = len(trips)
+        for i, trip in enumerate(trips):
+          if trip.GetStartTime() >= time:
+            start_sample_index = i
+            break
+
+        num_after_sample = num_trips - (start_sample_index + sample_size)
+        if num_after_sample < 0:
+          # Less than sample_size trips start after 'time' so return all the
+          # last sample_size trips.
+          num_after_sample = 0
+          start_sample_index = num_trips - sample_size
+
+      sample = []
+      for t in trips[start_sample_index:start_sample_index + sample_size]:
+        sample.append( (t.GetStartTime(), t.trip_id) )
+
+      patterns.append((name, pattern_id, start_sample_index, sample,
+                       num_after_sample, (0,1)[has_non_zero_trip_type]))
+
+    patterns.sort()
+    return patterns
+
+  def handle_json_wrapper_GET(self, handler, parsed_params):
+    """Call handler and output the return value in JSON."""
+    schedule = self.server.schedule
+    result = handler(parsed_params)
+    content = ResultEncoder().encode(result)
+    self.send_response(200)
+    self.send_header('Content-Type', 'text/plain')
+    self.send_header('Content-Length', str(len(content)))
+    self.end_headers()
+    self.wfile.write(content)
+
+  def handle_json_GET_routes(self, params):
+    """Return a list of all routes."""
+    schedule = self.server.schedule
+    result = []
+    for r in schedule.GetRouteList():
+      result.append( (r.route_id, r.route_short_name, r.route_long_name) )
+    result.sort(key = lambda x: x[1:3])
+    return result
+
+  def handle_json_GET_routerow(self, params):
+    schedule = self.server.schedule
+    route = schedule.GetRoute(params.get('route', None))
+    return [transitfeed.Route._FIELD_NAMES, route.GetFieldValuesTuple()]
+
+  def handle_json_GET_triprows(self, params):
+    """Return a list of rows from the feed file that are related to this
+    trip."""
+    schedule = self.server.schedule
+    try:
+      trip = schedule.GetTrip(params.get('trip', None))
+    except KeyError:
+      # if a non-existent trip is searched for, the return nothing
+      return
+    route = schedule.GetRoute(trip.route_id)
+    trip_row = dict(trip.iteritems())
+    route_row = dict(route.iteritems())
+    return [['trips.txt', trip_row], ['routes.txt', route_row]]
+
+  def handle_json_GET_tripstoptimes(self, params):
+    schedule = self.server.schedule
+    try:
+      trip = schedule.GetTrip(params.get('trip'))
+    except KeyError:
+       # if a non-existent trip is searched for, the return nothing
+      return
+    time_stops = trip.GetTimeStops()
+    stops = []
+    times = []
+    for arr,dep,stop in time_stops:
+      stops.append(StopToTuple(stop))
+      times.append(arr)
+    return [stops, times]
+
+  def handle_json_GET_tripshape(self, params):
+    schedule = self.server.schedule
+    try:
+      trip = schedule.GetTrip(params.get('trip'))
+    except KeyError:
+       # if a non-existent trip is searched for, the return nothing
+      return
+    points = []
+    if trip.shape_id:
+      shape = schedule.GetShape(trip.shape_id)
+      for (lat, lon, dist) in shape.points:
+        points.append((lat, lon))
+    else:
+      time_stops = trip.GetTimeStops()
+      for arr,dep,stop in time_stops:
+        points.append((stop.stop_lat, stop.stop_lon))
+    return points
+
+  def handle_json_GET_neareststops(self, params):
+    """Return a list of the nearest 'limit' stops to 'lat', 'lon'"""
+    schedule = self.server.schedule
+    lat = float(params.get('lat'))
+    lon = float(params.get('lon'))
+    limit = int(params.get('limit'))
+    stops = schedule.GetNearestStops(lat=lat, lon=lon, n=limit)
+    return [StopToTuple(s) for s in stops]
+
+  def handle_json_GET_boundboxstops(self, params):
+    """Return a list of up to 'limit' stops within bounding box with 'n','e'
+    and 's','w' in the NE and SW corners. Does not handle boxes crossing
+    longitude line 180."""
+    schedule = self.server.schedule
+    n = float(params.get('n'))
+    e = float(params.get('e'))
+    s = float(params.get('s'))
+    w = float(params.get('w'))
+    limit = int(params.get('limit'))
+    stops = schedule.GetStopsInBoundingBox(north=n, east=e, south=s, west=w, n=limit)
+    return [StopToTuple(s) for s in stops]
+
+  def handle_json_GET_stops(self, params):
+    schedule = self.server.schedule
+    return [StopToTuple(s) for s in schedule.GetStopList()]
+
+  def handle_json_GET_stopsearch(self, params):
+    schedule = self.server.schedule
+    query = params.get('q', None).lower()
+    matches = []
+    for s in schedule.GetStopList():
+      if s.stop_id.lower().find(query) != -1 or s.stop_name.lower().find(query) != -1:
+        matches.append(StopToTuple(s))
+    return matches
+
+  def handle_json_GET_stop(self, params):
+    schedule = self.server.schedule
+    query = params.get('stop_id', None).lower()
+    for s in schedule.GetStopList():
+      if s.stop_id.lower() == query:
+        return StopToTuple(s)
+    return []
+
+  def handle_json_GET_stoptrips(self, params):
+    """Given a stop_id and time in seconds since midnight return the next
+    trips to visit the stop."""
+    schedule = self.server.schedule
+    stop = schedule.GetStop(params.get('stop', None))
+    time = int(params.get('time', 0))
+    time_trips = stop.GetStopTimeTrips(schedule)
+    time_trips.sort()  # OPT: use bisect.insort to make this O(N*ln(N)) -> O(N)
+    # Keep the first 15 after param 'time'.
+    # Need make a tuple to find correct bisect point
+    time_trips = time_trips[bisect.bisect_left(time_trips, (time, 0)):]
+    time_trips = time_trips[:15]
+    # TODO: combine times for a route to show next 2 departure times
+    result = []
+    for time, (trip, index), tp in time_trips:
+      headsign = None
+      # Find the most recent headsign from the StopTime objects
+      for stoptime in trip.GetStopTimes()[index::-1]:
+        if stoptime.stop_headsign:
+          headsign = stoptime.stop_headsign
+          break
+      # If stop_headsign isn't found, look for a trip_headsign
+      if not headsign:
+        headsign = trip.trip_headsign
+      route = schedule.GetRoute(trip.route_id)
+      trip_name = ''
+      if route.route_short_name:
+        trip_name += route.route_short_name
+      if route.route_long_name:
+        if len(trip_name):
+          trip_name += " - "
+        trip_name += route.route_long_name
+      if headsign:
+        trip_name += " (Direction: %s)" % headsign
+
+      result.append((time, (trip.trip_id, trip_name, trip.service_id), tp))
+    return result
+
+  def handle_GET_ttablegraph(self,params):
+    """Draw a Marey graph in SVG for a pattern (collection of trips in a route
+    that visit the same sequence of stops)."""
+    schedule = self.server.schedule
+    marey = MareyGraph()
+    trip = schedule.GetTrip(params.get('trip', None))
+    route = schedule.GetRoute(trip.route_id)
+    height = int(params.get('height', 300))
+
+    if not route:
+      print 'no such route'
+      self.send_error(404)
+      return
+
+    pattern_id_trip_dict = route.GetPatternIdTripDict()
+    pattern_id = trip.pattern_id
+    if pattern_id not in pattern_id_trip_dict:
+      print 'no pattern %s found in %s' % (pattern_id, pattern_id_trip_dict.keys())
+      self.send_error(404)
+      return
+    triplist = pattern_id_trip_dict[pattern_id]
+
+    pattern_start_time = min((t.GetStartTime() for t in triplist))
+    pattern_end_time = max((t.GetEndTime() for t in triplist))
+
+    marey.SetSpan(pattern_start_time,pattern_end_time)
+    marey.Draw(triplist[0].GetPattern(), triplist, height)
+
+    content = marey.Draw()
+
+    self.send_response(200)
+    self.send_header('Content-Type', 'image/svg+xml')
+    self.send_header('Content-Length', str(len(content)))
+    self.end_headers()
+    self.wfile.write(content)
+
+
+def FindPy2ExeBase():
+  """If this is running in py2exe return the install directory else return
+  None"""
+  # py2exe puts gtfsscheduleviewer in library.zip. For py2exe setup.py is
+  # configured to put the data next to library.zip.
+  windows_ending = gtfsscheduleviewer.__file__.find('\\library.zip\\')
+  if windows_ending != -1:
+    return transitfeed.__file__[:windows_ending]
+  else:
+    return None
+
+
+def FindDefaultFileDir():
+  """Return the path of the directory containing the static files. By default
+  the directory is called 'files'. The location depends on where setup.py put
+  it."""
+  base = FindPy2ExeBase()
+  if base:
+    return os.path.join(base, 'schedule_viewer_files')
+  else:
+    # For all other distributions 'files' is in the gtfsscheduleviewer
+    # directory.
+    base = os.path.dirname(gtfsscheduleviewer.__file__)  # Strip __init__.py
+    return os.path.join(base, 'files')
+
+
+def GetDefaultKeyFilePath():
+  """In py2exe return absolute path of file in the base directory and in all
+  other distributions return relative path 'key.txt'"""
+  windows_base = FindPy2ExeBase()
+  if windows_base:
+    return os.path.join(windows_base, 'key.txt')
+  else:
+    return 'key.txt'
+
+
+def main(RequestHandlerClass = ScheduleRequestHandler):
+  usage = \
+'''%prog [options] [<input GTFS.zip>]
+
+Runs a webserver that lets you explore a <input GTFS.zip> in your browser.
+
+If <input GTFS.zip> is omited the filename is read from the console. Dragging
+a file into the console may enter the filename.
+'''
+  parser = util.OptionParserLongError(
+      usage=usage, version='%prog '+transitfeed.__version__)
+  parser.add_option('--feed_filename', '--feed', dest='feed_filename',
+                    help='file name of feed to load')
+  parser.add_option('--key', dest='key',
+                    help='Google Maps API key or the name '
+                    'of a text file that contains an API key')
+  parser.add_option('--host', dest='host', help='Host name of Google Maps')
+  parser.add_option('--port', dest='port', type='int',
+                    help='port on which to listen')
+  parser.add_option('--file_dir', dest='file_dir',
+                    help='directory containing static files')
+  parser.add_option('-n', '--noprompt', action='store_false',
+                    dest='manual_entry',
+                    help='disable interactive prompts')
+  parser.set_defaults(port=8765,
+                      host='maps.google.com',
+                      file_dir=FindDefaultFileDir(),
+                      manual_entry=True)
+  (options, args) = parser.parse_args()
+
+  if not os.path.isfile(os.path.join(options.file_dir, 'index.html')):
+    print "Can't find index.html with --file_dir=%s" % options.file_dir
+    exit(1)
+
+  if not options.feed_filename and len(args) == 1:
+    options.feed_filename = args[0]
+
+  if not options.feed_filename and options.manual_entry:
+    options.feed_filename = raw_input('Enter Feed Location: ').strip('"')
+
+  default_key_file = GetDefaultKeyFilePath()
+  if not options.key and os.path.isfile(default_key_file):
+    options.key = open(default_key_file).read().strip()
+
+  if options.key and os.path.isfile(options.key):
+    options.key = open(options.key).read().strip()
+
+  schedule = transitfeed.Schedule(problem_reporter=transitfeed.ProblemReporter())
+  print 'Loading data from feed "%s"...' % options.feed_filename
+  print '(this may take a few minutes for larger cities)'
+  schedule.Load(options.feed_filename)
+
+  server = StoppableHTTPServer(server_address=('', options.port),
+                               RequestHandlerClass=RequestHandlerClass)
+  server.key = options.key
+  server.schedule = schedule
+  server.file_dir = options.file_dir
+  server.host = options.host
+  server.feed_path = options.feed_filename
+
+  print ("To view, point your browser at http://localhost:%d/" %
+         (server.server_port))
+  server.serve_forever()
+
+
+if __name__ == '__main__':
+  main()
+