|
#!/usr/bin/python2.5 |
|
|
|
# Copyright (C) 2007 Google Inc. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
""" |
|
An example application that uses the transitfeed module. |
|
|
|
You must provide a Google Maps API key. |
|
""" |
|
|
|
|
|
import BaseHTTPServer, sys, urlparse |
|
import bisect |
|
from gtfsscheduleviewer.marey_graph import MareyGraph |
|
import gtfsscheduleviewer |
|
import mimetypes |
|
import os.path |
|
import re |
|
import signal |
|
import simplejson |
|
import socket |
|
import time |
|
import transitfeed |
|
from transitfeed import util |
|
import urllib |
|
|
|
|
|
# By default Windows kills Python with Ctrl+Break. Instead make Ctrl+Break |
|
# raise a KeyboardInterrupt. |
|
if hasattr(signal, 'SIGBREAK'): |
|
signal.signal(signal.SIGBREAK, signal.default_int_handler) |
|
|
|
|
|
mimetypes.add_type('text/plain', '.vbs') |
|
|
|
|
|
class ResultEncoder(simplejson.JSONEncoder): |
|
def default(self, obj): |
|
try: |
|
iterable = iter(obj) |
|
except TypeError: |
|
pass |
|
else: |
|
return list(iterable) |
|
return simplejson.JSONEncoder.default(self, obj) |
|
|
|
# Code taken from |
|
# http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/425210/index_txt |
|
# An alternate approach is shown at |
|
# http://mail.python.org/pipermail/python-list/2003-July/212751.html |
|
# but it requires multiple threads. A sqlite object can only be used from one |
|
# thread. |
|
class StoppableHTTPServer(BaseHTTPServer.HTTPServer): |
|
def server_bind(self): |
|
BaseHTTPServer.HTTPServer.server_bind(self) |
|
self.socket.settimeout(1) |
|
self._run = True |
|
|
|
def get_request(self): |
|
while self._run: |
|
try: |
|
sock, addr = self.socket.accept() |
|
sock.settimeout(None) |
|
return (sock, addr) |
|
except socket.timeout: |
|
pass |
|
|
|
def stop(self): |
|
self._run = False |
|
|
|
def serve(self): |
|
while self._run: |
|
self.handle_request() |
|
|
|
|
|
def StopToTuple(stop): |
|
"""Return tuple as expected by javascript function addStopMarkerFromList""" |
|
return (stop.stop_id, stop.stop_name, float(stop.stop_lat), |
|
float(stop.stop_lon), stop.location_type) |
|
|
|
|
|
class ScheduleRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): |
|
def do_GET(self): |
|
scheme, host, path, x, params, fragment = urlparse.urlparse(self.path) |
|
parsed_params = {} |
|
for k in params.split('&'): |
|
k = urllib.unquote(k) |
|
if '=' in k: |
|
k, v = k.split('=', 1) |
|
parsed_params[k] = unicode(v, 'utf8') |
|
else: |
|
parsed_params[k] = '' |
|
|
|
if path == '/': |
|
return self.handle_GET_home() |
|
|
|
m = re.match(r'/json/([a-z]{1,64})', path) |
|
if m: |
|
handler_name = 'handle_json_GET_%s' % m.group(1) |
|
handler = getattr(self, handler_name, None) |
|
if callable(handler): |
|
return self.handle_json_wrapper_GET(handler, parsed_params) |
|
|
|
# Restrict allowable file names to prevent relative path attacks etc |
|
m = re.match(r'/file/([a-z0-9_-]{1,64}\.?[a-z0-9_-]{1,64})$', path) |
|
if m and m.group(1): |
|
try: |
|
f, mime_type = self.OpenFile(m.group(1)) |
|
return self.handle_static_file_GET(f, mime_type) |
|
except IOError, e: |
|
print "Error: unable to open %s" % m.group(1) |
|
# Ignore and treat as 404 |
|
|
|
m = re.match(r'/([a-z]{1,64})', path) |
|
if m: |
|
handler_name = 'handle_GET_%s' % m.group(1) |
|
handler = getattr(self, handler_name, None) |
|
if callable(handler): |
|
return handler(parsed_params) |
|
|
|
return self.handle_GET_default(parsed_params, path) |
|
|
|
def OpenFile(self, filename): |
|
"""Try to open filename in the static files directory of this server. |
|
Return a tuple (file object, string mime_type) or raise an exception.""" |
|
(mime_type, encoding) = mimetypes.guess_type(filename) |
|
assert mime_type |
|
# A crude guess of when we should use binary mode. Without it non-unix |
|
# platforms may corrupt binary files. |
|
if mime_type.startswith('text/'): |
|
mode = 'r' |
|
else: |
|
mode = 'rb' |
|
return open(os.path.join(self.server.file_dir, filename), mode), mime_type |
|
|
|
def handle_GET_default(self, parsed_params, path): |
|
self.send_error(404) |
|
|
|
def handle_static_file_GET(self, fh, mime_type): |
|
content = fh.read() |
|
self.send_response(200) |
|
self.send_header('Content-Type', mime_type) |
|
self.send_header('Content-Length', str(len(content))) |
|
self.end_headers() |
|
self.wfile.write(content) |
|
|
|
def AllowEditMode(self): |
|
return False |
|
|
|
def handle_GET_home(self): |
|
schedule = self.server.schedule |
|
(min_lat, min_lon, max_lat, max_lon) = schedule.GetStopBoundingBox() |
|
forbid_editing = ('true', 'false')[self.AllowEditMode()] |
|
|
|
agency = ', '.join(a.agency_name for a in schedule.GetAgencyList()).encode('utf-8') |
|
|
|
key = self.server.key |
|
host = self.server.host |
|
|
|
# A very simple template system. For a fixed set of values replace [xxx] |
|
# with the value of local variable xxx |
|
f, _ = self.OpenFile('index.html') |
|
content = f.read() |
|
for v in ('agency', 'min_lat', 'min_lon', 'max_lat', 'max_lon', 'key', |
|
'host', 'forbid_editing'): |
|
content = content.replace('[%s]' % v, str(locals()[v])) |
|
|
|
self.send_response(200) |
|
self.send_header('Content-Type', 'text/html') |
|
self.send_header('Content-Length', str(len(content))) |
|
self.end_headers() |
|
self.wfile.write(content) |
|
|
|
def handle_json_GET_routepatterns(self, params): |
|
"""Given a route_id generate a list of patterns of the route. For each |
|
pattern include some basic information and a few sample trips.""" |
|
schedule = self.server.schedule |
|
route = schedule.GetRoute(params.get('route', None)) |
|
if not route: |
|
self.send_error(404) |
|
return |
|
time = int(params.get('time', 0)) |
|
sample_size = 10 # For each pattern return the start time for this many trips |
|
|
|
pattern_id_trip_dict = route.GetPatternIdTripDict() |
|
patterns = [] |
|
|
|
for pattern_id, trips in pattern_id_trip_dict.items(): |
|
time_stops = trips[0].GetTimeStops() |
|
if not time_stops: |
|
continue |
|
has_non_zero_trip_type = False; |
|
for trip in trips: |
|
if trip['trip_type'] and trip['trip_type'] != '0': |
|
has_non_zero_trip_type = True |
|
name = u'%s to %s, %d stops' % (time_stops[0][2].stop_name, time_stops[-1][2].stop_name, len(time_stops)) |
|
transitfeed.SortListOfTripByTime(trips) |
|
|
|
num_trips = len(trips) |
|
if num_trips <= sample_size: |
|
start_sample_index = 0 |
|
num_after_sample = 0 |
|
else: |
|
# Will return sample_size trips that start after the 'time' param. |
|
|
|
# Linear search because I couldn't find a built-in way to do a binary |
|
# search with a custom key. |
|
start_sample_index = len(trips) |
|
for i, trip in enumerate(trips): |
|
if trip.GetStartTime() >= time: |
|
start_sample_index = i |
|
break |
|
|
|
num_after_sample = num_trips - (start_sample_index + sample_size) |
|
if num_after_sample < 0: |
|
# Less than sample_size trips start after 'time' so return all the |
|
# last sample_size trips. |
|
num_after_sample = 0 |
|
start_sample_index = num_trips - sample_size |
|
|
|
sample = [] |
|
for t in trips[start_sample_index:start_sample_index + sample_size]: |
|
sample.append( (t.GetStartTime(), t.trip_id) ) |
|
|
|
patterns.append((name, pattern_id, start_sample_index, sample, |
|
num_after_sample, (0,1)[has_non_zero_trip_type])) |
|
|
|
patterns.sort() |
|
return patterns |
|
|
|
def handle_json_wrapper_GET(self, handler, parsed_params): |
|
"""Call handler and output the return value in JSON.""" |
|
schedule = self.server.schedule |
|
result = handler(parsed_params) |
|
content = ResultEncoder().encode(result) |
|
self.send_response(200) |
|
self.send_header('Content-Type', 'text/plain') |
|
self.send_header('Content-Length', str(len(content))) |
|
self.end_headers() |
|
self.wfile.write(content) |
|
|
|
def handle_json_GET_routes(self, params): |
|
"""Return a list of all routes.""" |
|
schedule = self.server.schedule |
|
result = [] |
|
for r in schedule.GetRouteList(): |
|
result.append( (r.route_id, r.route_short_name, r.route_long_name) ) |
|
result.sort(key = lambda x: x[1:3]) |
|
return result |
|
|
|
def handle_json_GET_routerow(self, params): |
|
schedule = self.server.schedule |
|
route = schedule.GetRoute(params.get('route', None)) |
|
return [transitfeed.Route._FIELD_NAMES, route.GetFieldValuesTuple()] |
|
|
|
def handle_json_GET_triprows(self, params): |
|
"""Return a list of rows from the feed file that are related to this |
|
trip.""" |
|
schedule = self.server.schedule |
|
try: |
|
trip = schedule.GetTrip(params.get('trip', None)) |
|
except KeyError: |
|
# if a non-existent trip is searched for, the return nothing |
|
return |
|
route = schedule.GetRoute(trip.route_id) |
|
trip_row = dict(trip.iteritems()) |
|
route_row = dict(route.iteritems()) |
|
return [['trips.txt', trip_row], ['routes.txt', route_row]] |
|
|
|
def handle_json_GET_tripstoptimes(self, params): |
|
schedule = self.server.schedule |
|
try: |
|
trip = schedule.GetTrip(params.get('trip')) |
|
except KeyError: |
|
# if a non-existent trip is searched for, the return nothing |
|
return |
|
time_stops = trip.GetTimeStops() |
|
stops = [] |
|
times = [] |
|
for arr,dep,stop in time_stops: |
|
stops.append(StopToTuple(stop)) |
|
times.append(arr) |
|
return [stops, times] |
|
|
|
def handle_json_GET_tripshape(self, params): |
|
schedule = self.server.schedule |
|
try: |
|
trip = schedule.GetTrip(params.get('trip')) |
|
except KeyError: |
|
# if a non-existent trip is searched for, the return nothing |
|
return |
|
points = [] |
|
if trip.shape_id: |
|
shape = schedule.GetShape(trip.shape_id) |
|
for (lat, lon, dist) in shape.points: |
|
points.append((lat, lon)) |
|
else: |
|
time_stops = trip.GetTimeStops() |
|
for arr,dep,stop in time_stops: |
|
points.append((stop.stop_lat, stop.stop_lon)) |
|
return points |
|
|
|
def handle_json_GET_neareststops(self, params): |
|
"""Return a list of the nearest 'limit' stops to 'lat', 'lon'""" |
|
schedule = self.server.schedule |
|
lat = float(params.get('lat')) |
|
lon = float(params.get('lon')) |
|
limit = int(params.get('limit')) |
|
stops = schedule.GetNearestStops(lat=lat, lon=lon, n=limit) |
|
return [StopToTuple(s) for s in stops] |
|
|
|
def handle_json_GET_boundboxstops(self, params): |
|
"""Return a list of up to 'limit' stops within bounding box with 'n','e' |
|
and 's','w' in the NE and SW corners. Does not handle boxes crossing |
|
longitude line 180.""" |
|
schedule = self.server.schedule |
|
n = float(params.get('n')) |
|
e = float(params.get('e')) |
|
s = float(params.get('s')) |
|
w = float(params.get('w')) |
|
limit = int(params.get('limit')) |
|
stops = schedule.GetStopsInBoundingBox(north=n, east=e, south=s, west=w, n=limit) |
|
return [StopToTuple(s) for s in stops] |
|
|
|
def handle_json_GET_stops(self, params): |
|
schedule = self.server.schedule |
|
return [StopToTuple(s) for s in schedule.GetStopList()] |
|
|
|
def handle_json_GET_stopsearch(self, params): |
|
schedule = self.server.schedule |
|
query = params.get('q', None).lower() |
|
matches = [] |
|
for s in schedule.GetStopList(): |
|
if s.stop_id.lower().find(query) != -1 or s.stop_name.lower().find(query) != -1: |
|
matches.append(StopToTuple(s)) |
|
return matches |
|
|
|
def handle_json_GET_stop(self, params): |
|
schedule = self.server.schedule |
|
query = params.get('stop_id', None).lower() |
|
for s in schedule.GetStopList(): |
|
if s.stop_id.lower() == query: |
|
return StopToTuple(s) |
|
return [] |
|
|
|
def handle_json_GET_stoptrips(self, params): |
|
"""Given a stop_id and time in seconds since midnight return the next |
|
trips to visit the stop.""" |
|
schedule = self.server.schedule |
|
stop = schedule.GetStop(params.get('stop', None)) |
|
time = int(params.get('time', 0)) |
|
time_trips = stop.GetStopTimeTrips(schedule) |
|
time_trips.sort() # OPT: use bisect.insort to make this O(N*ln(N)) -> O(N) |
|
# Keep the first 15 after param 'time'. |
|
# Need make a tuple to find correct bisect point |
|
time_trips = time_trips[bisect.bisect_left(time_trips, (time, 0)):] |
|
time_trips = time_trips[:15] |
|
# TODO: combine times for a route to show next 2 departure times |
|
result = [] |
|
for time, (trip, index), tp in time_trips: |
|
headsign = None |
|
# Find the most recent headsign from the StopTime objects |
|
for stoptime in trip.GetStopTimes()[index::-1]: |
|
if stoptime.stop_headsign: |
|
headsign = stoptime.stop_headsign |
|
break |
|
# If stop_headsign isn't found, look for a trip_headsign |
|
if not headsign: |
|
headsign = trip.trip_headsign |
|
route = schedule.GetRoute(trip.route_id) |
|
trip_name = '' |
|
if route.route_short_name: |
|
trip_name += route.route_short_name |
|
if route.route_long_name: |
|
if len(trip_name): |
|
trip_name += " - " |
|
trip_name += route.route_long_name |
|
if headsign: |
|
trip_name += " (Direction: %s)" % headsign |
|
|
|
result.append((time, (trip.trip_id, trip_name, trip.service_id), tp)) |
|
return result |
|
|
|
def handle_GET_ttablegraph(self,params): |
|
"""Draw a Marey graph in SVG for a pattern (collection of trips in a route |
|
that visit the same sequence of stops).""" |
|
schedule = self.server.schedule |
|
marey = MareyGraph() |
|
trip = schedule.GetTrip(params.get('trip', None)) |
|
route = schedule.GetRoute(trip.route_id) |
|
height = int(params.get('height', 300)) |
|
|
|
if not route: |
|
print 'no such route' |
|
self.send_error(404) |
|
return |
|
|
|
pattern_id_trip_dict = route.GetPatternIdTripDict() |
|
pattern_id = trip.pattern_id |
|
if pattern_id not in pattern_id_trip_dict: |
|
print 'no pattern %s found in %s' % (pattern_id, pattern_id_trip_dict.keys()) |
|
self.send_error(404) |
|
return |
|
triplist = pattern_id_trip_dict[pattern_id] |
|
|
|
pattern_start_time = min((t.GetStartTime() for t in triplist)) |
|
pattern_end_time = max((t.GetEndTime() for t in triplist)) |
|
|
|
marey.SetSpan(pattern_start_time,pattern_end_time) |
|
marey.Draw(triplist[0].GetPattern(), triplist, height) |
|
|
|
content = marey.Draw() |
|
|
|
self.send_response(200) |
|
self.send_header('Content-Type', 'image/svg+xml') |
|
self.send_header('Content-Length', str(len(content))) |
|
self.end_headers() |
|
self.wfile.write(content) |
|
|
|
|
|
def FindPy2ExeBase(): |
|
"""If this is running in py2exe return the install directory else return |
|
None""" |
|
# py2exe puts gtfsscheduleviewer in library.zip. For py2exe setup.py is |
|
# configured to put the data next to library.zip. |
|
windows_ending = gtfsscheduleviewer.__file__.find('\\library.zip\\') |
|
if windows_ending != -1: |
|
return transitfeed.__file__[:windows_ending] |
|
else: |
|
return None |
|
|
|
|
|
def FindDefaultFileDir(): |
|
"""Return the path of the directory containing the static files. By default |
|
the directory is called 'files'. The location depends on where setup.py put |
|
it.""" |
|
base = FindPy2ExeBase() |
|
if base: |
|
return os.path.join(base, 'schedule_viewer_files') |
|
else: |
|
# For all other distributions 'files' is in the gtfsscheduleviewer |
|
# directory. |
|
base = os.path.dirname(gtfsscheduleviewer.__file__) # Strip __init__.py |
|
return os.path.join(base, 'files') |
|
|
|
|
|
def GetDefaultKeyFilePath(): |
|
"""In py2exe return absolute path of file in the base directory and in all |
|
other distributions return relative path 'key.txt'""" |
|
windows_base = FindPy2ExeBase() |
|
if windows_base: |
|
return os.path.join(windows_base, 'key.txt') |
|
else: |
|
return 'key.txt' |
|
|
|
|
|
def main(RequestHandlerClass = ScheduleRequestHandler): |
|
usage = \ |
|
'''%prog [options] [<input GTFS.zip>] |
|
|
|
Runs a webserver that lets you explore a <input GTFS.zip> in your browser. |
|
|
|
If <input GTFS.zip> is omited the filename is read from the console. Dragging |
|
a file into the console may enter the filename. |
|
''' |
|
parser = util.OptionParserLongError( |
|
usage=usage, version='%prog '+transitfeed.__version__) |
|
parser.add_option('--feed_filename', '--feed', dest='feed_filename', |
|
help='file name of feed to load') |
|
parser.add_option('--key', dest='key', |
|
help='Google Maps API key or the name ' |
|
'of a text file that contains an API key') |
|
parser.add_option('--host', dest='host', help='Host name of Google Maps') |
|
parser.add_option('--port', dest='port', type='int', |
|
help='port on which to listen') |
|
parser.add_option('--file_dir', dest='file_dir', |
|
help='directory containing static files') |
|
parser.add_option('-n', '--noprompt', action='store_false', |
|
dest='manual_entry', |
|
help='disable interactive prompts') |
|
parser.set_defaults(port=8765, |
|
host='maps.google.com', |
|
file_dir=FindDefaultFileDir(), |
|
manual_entry=True) |
|
(options, args) = parser.parse_args() |
|
|
|
if not os.path.isfile(os.path.join(options.file_dir, 'index.html')): |
|
print "Can't find index.html with --file_dir=%s" % options.file_dir |
|
exit(1) |
|
|
|
if not options.feed_filename and len(args) == 1: |
|
options.feed_filename = args[0] |
|
|
|
if not options.feed_filename and options.manual_entry: |
|
options.feed_filename = raw_input('Enter Feed Location: ').strip('"') |
|
|
|
default_key_file = GetDefaultKeyFilePath() |
|
if not options.key and os.path.isfile(default_key_file): |
|
options.key = open(default_key_file).read().strip() |
|
|
|
if options.key and os.path.isfile(options.key): |
|
options.key = open(options.key).read().strip() |
|
|
|
schedule = transitfeed.Schedule(problem_reporter=transitfeed.ProblemReporter()) |
|
print 'Loading data from feed "%s"...' % options.feed_filename |
|
print '(this may take a few minutes for larger cities)' |
|
schedule.Load(options.feed_filename) |
|
|
|
server = StoppableHTTPServer(server_address=('', options.port), |
|
RequestHandlerClass=RequestHandlerClass) |
|
server.key = options.key |
|
server.schedule = schedule |
|
server.file_dir = options.file_dir |
|
server.host = options.host |
|
server.feed_path = options.feed_filename |
|
|
|
print ("To view, point your browser at http://localhost:%d/" % |
|
(server.server_port)) |
|
server.serve_forever() |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|