Upgrade origin-src to google transit feed 1.2.6
[bus.git] / origin-src / transitfeed-1.2.6 / shape_importer.py
blob:a/origin-src/transitfeed-1.2.6/shape_importer.py -> blob:b/origin-src/transitfeed-1.2.6/shape_importer.py
--- a/origin-src/transitfeed-1.2.6/shape_importer.py
+++ b/origin-src/transitfeed-1.2.6/shape_importer.py
@@ -1,1 +1,291 @@
-
+#!/usr/bin/python2.4
+#
+# Copyright 2007 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A utility program to help add shapes to an existing GTFS feed.
+
+Requires the ogr python package.
+"""
+
+__author__ = 'chris.harrelson.code@gmail.com (Chris Harrelson)'
+
+import csv
+import glob
+import ogr
+import os
+import shutil
+import sys
+import tempfile
+import transitfeed
+from transitfeed import shapelib
+from transitfeed import util
+import zipfile
+
+
+class ShapeImporterError(Exception):
+  pass
+
+
+def PrintColumns(shapefile):
+  """
+  Print the columns of layer 0 of the shapefile to the screen.
+  """
+  ds = ogr.Open(shapefile)
+  layer = ds.GetLayer(0)
+  if len(layer) == 0:
+    raise ShapeImporterError("Layer 0 has no elements!")
+
+  feature = layer.GetFeature(0)
+  print "%d features" % feature.GetFieldCount()
+  for j in range(0, feature.GetFieldCount()):
+    print '--' + feature.GetFieldDefnRef(j).GetName() + \
+          ': ' + feature.GetFieldAsString(j)
+
+
+def AddShapefile(shapefile, graph, key_cols):
+  """
+  Adds shapes found in the given shape filename to the given polyline
+  graph object.
+  """
+  ds = ogr.Open(shapefile)
+  layer = ds.GetLayer(0)
+
+  for i in range(0, len(layer)):
+    feature = layer.GetFeature(i)
+
+    geometry = feature.GetGeometryRef()
+
+    if key_cols:
+      key_list = []
+      for col in key_cols:
+        key_list.append(str(feature.GetField(col)))
+      shape_id = '-'.join(key_list)
+    else:
+      shape_id = '%s-%d' % (shapefile, i)
+
+    poly = shapelib.Poly(name=shape_id)
+    for j in range(0, geometry.GetPointCount()):
+      (lat, lng) = (round(geometry.GetY(j), 15), round(geometry.GetX(j), 15))
+      poly.AddPoint(shapelib.Point.FromLatLng(lat, lng))
+    graph.AddPoly(poly)
+
+  return graph
+
+
+def GetMatchingShape(pattern_poly, trip, matches, max_distance, verbosity=0):
+  """
+  Tries to find a matching shape for the given pattern Poly object,
+  trip, and set of possibly matching Polys from which to choose a match.
+  """
+  if len(matches) == 0:
+    print ('No matching shape found within max-distance %d for trip %s '
+           % (max_distance, trip.trip_id))
+    return None
+
+  if verbosity >= 1:
+    for match in matches:
+      print "match: size %d" % match.GetNumPoints()
+  scores = [(pattern_poly.GreedyPolyMatchDist(match), match)
+            for match in matches]
+
+  scores.sort()
+
+  if scores[0][0] > max_distance:
+    print ('No matching shape found within max-distance %d for trip %s '
+           '(min score was %f)'
+           % (max_distance, trip.trip_id, scores[0][0]))
+    return None
+
+  return scores[0][1]
+
+def AddExtraShapes(extra_shapes_txt, graph):
+  """
+  Add extra shapes into our input set by parsing them out of a GTFS-formatted
+  shapes.txt file.  Useful for manually adding lines to a shape file, since it's
+  a pain to edit .shp files.
+  """
+
+  print "Adding extra shapes from %s" % extra_shapes_txt
+  try:
+    tmpdir = tempfile.mkdtemp()
+    shutil.copy(extra_shapes_txt, os.path.join(tmpdir, 'shapes.txt'))
+    loader = transitfeed.ShapeLoader(tmpdir)
+    schedule = loader.Load()
+    for shape in schedule.GetShapeList():
+      print "Adding extra shape: %s" % shape.shape_id
+      graph.AddPoly(ShapeToPoly(shape))
+  finally:
+    if tmpdir:
+      shutil.rmtree(tmpdir)
+
+
+# Note: this method lives here to avoid cross-dependencies between
+# shapelib and transitfeed.
+def ShapeToPoly(shape):
+  poly = shapelib.Poly(name=shape.shape_id)
+  for lat, lng, distance in shape.points:
+    point = shapelib.Point.FromLatLng(round(lat, 15), round(lng, 15))
+    poly.AddPoint(point)
+  return poly
+
+
+def ValidateArgs(options_parser, options, args):
+  if not (args and options.source_gtfs and options.dest_gtfs):
+    options_parser.error("You must specify a source and dest GTFS file, "
+                         "and at least one source shapefile")
+
+
+def DefineOptions():
+  usage = \
+"""%prog [options] --source_gtfs=<input GTFS.zip> --dest_gtfs=<output GTFS.zip>\
+ <input.shp> [<input.shp>...]
+
+Try to match shapes in one or more SHP files to trips in a GTFS file."""
+  options_parser = util.OptionParserLongError(
+      usage=usage, version='%prog '+transitfeed.__version__)
+  options_parser.add_option("--print_columns",
+                            action="store_true",
+                            default=False,
+                            dest="print_columns",
+                            help="Print column names in shapefile DBF and exit")
+  options_parser.add_option("--keycols",
+                            default="",
+                            dest="keycols",
+                            help="Comma-separated list of the column names used"
+                                 "to index shape ids")
+  options_parser.add_option("--max_distance",
+                            type="int",
+                            default=150,
+                            dest="max_distance",
+                            help="Max distance from a shape to which to match")
+  options_parser.add_option("--source_gtfs",
+                            default="",
+                            dest="source_gtfs",
+                            metavar="FILE",
+                            help="Read input GTFS from FILE")
+  options_parser.add_option("--dest_gtfs",
+                            default="",
+                            dest="dest_gtfs",
+                            metavar="FILE",
+                            help="Write output GTFS with shapes to FILE")
+  options_parser.add_option("--extra_shapes",
+                            default="",
+                            dest="extra_shapes",
+                            metavar="FILE",
+                            help="Extra shapes.txt (CSV) formatted file")
+  options_parser.add_option("--verbosity",
+                            type="int",
+                            default=0,
+                            dest="verbosity",
+                            help="Verbosity level. Higher is more verbose")
+  return options_parser
+
+
+def main(key_cols):
+  print 'Parsing shapefile(s)...'
+  graph = shapelib.PolyGraph()
+  for arg in args:
+    print '  ' + arg
+    AddShapefile(arg, graph, key_cols)
+
+  if options.extra_shapes:
+    AddExtraShapes(options.extra_shapes, graph)
+
+  print 'Loading GTFS from %s...' % options.source_gtfs
+  schedule = transitfeed.Loader(options.source_gtfs).Load()
+  shape_count = 0
+  pattern_count = 0
+
+  verbosity = options.verbosity
+
+  print 'Matching shapes to trips...'
+  for route in schedule.GetRouteList():
+    print 'Processing route', route.route_short_name
+    patterns = route.GetPatternIdTripDict()
+    for pattern_id, trips in patterns.iteritems():
+      pattern_count += 1
+      pattern = trips[0].GetPattern()
+
+      poly_points = [shapelib.Point.FromLatLng(p.stop_lat, p.stop_lon)
+                     for p in pattern]
+      if verbosity >= 2:
+        print "\npattern %d, %d points:" % (pattern_id, len(poly_points))
+        for i, (stop, point) in enumerate(zip(pattern, poly_points)):
+          print "Stop %d '%s': %s" % (i + 1, stop.stop_name, point.ToLatLng())
+
+      # First, try to find polys that run all the way from
+      # the start of the trip to the end.
+      matches = graph.FindMatchingPolys(poly_points[0], poly_points[-1],
+                                        options.max_distance)
+      if not matches:
+        # Try to find a path through the graph, joining
+        # multiple edges to find a path that covers all the
+        # points in the trip.  Some shape files are structured
+        # this way, with a polyline for each segment between
+        # stations instead of a polyline covering an entire line.
+        shortest_path = graph.FindShortestMultiPointPath(poly_points,
+                                                         options.max_distance,
+                                                         verbosity=verbosity)
+        if shortest_path:
+          matches = [shortest_path]
+        else:
+          matches = []
+
+      pattern_poly = shapelib.Poly(poly_points)
+      shape_match = GetMatchingShape(pattern_poly, trips[0],
+                                     matches, options.max_distance,
+                                     verbosity=verbosity)
+      if shape_match:
+        shape_count += 1
+        # Rename shape for readability.
+        shape_match = shapelib.Poly(points=shape_match.GetPoints(),
+                                           name="shape_%d" % shape_count)
+        for trip in trips:
+          try:
+            shape = schedule.GetShape(shape_match.GetName())
+          except KeyError:
+            shape = transitfeed.Shape(shape_match.GetName())
+            for point in shape_match.GetPoints():
+              (lat, lng) = point.ToLatLng()
+              shape.AddPoint(lat, lng)
+            schedule.AddShapeObject(shape)
+          trip.shape_id = shape.shape_id
+
+  print "Matched %d shapes out of %d patterns" % (shape_count, pattern_count)
+  schedule.WriteGoogleTransitFeed(options.dest_gtfs)
+
+
+if __name__ == '__main__':
+  # Import psyco if available for better performance.
+  try:
+    import psyco
+    psyco.full()
+  except ImportError:
+    pass
+
+  options_parser = DefineOptions()
+  (options, args) = options_parser.parse_args()
+
+  ValidateArgs(options_parser, options, args)
+
+  if options.print_columns:
+    for arg in args:
+      PrintColumns(arg)
+    sys.exit(0)
+
+  key_cols = options.keycols.split(',')
+
+  main(key_cols)
+