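# NOTE(review): the entries below look like share-link residue from a web
# page rather than part of this module -- confirm and remove if so.  They are
# kept as ASCII comments so that the file parses.
#   - Facebook
#   - Twitter
#   - Reddit
#   - StumbleUpon
#   - Digg
#   - email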

from sqlalchemy.orm.interfaces import AttributeExtension
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.types import TypeEngine
from sqlalchemy.sql import expression
 
# Python datatypes
 
class GisElement(object):
    """Base class for objects that represent a geometry value.

    Subclasses are expected to provide a ``desc`` attribute holding the
    description of the geometry; ``__str__`` and ``__repr__`` read it.

    """

    @property
    def wkt(self):
        """SQL expression ``AsText(<this value>)``.

        Only the expression is built here; the caller has to execute it
        (for example with ``session.scalar()``) to obtain the result.
        """
        # Imported locally: ``func`` and ``literal`` are not among the
        # module-level imports, so relying on globals raises NameError
        # whenever this module is imported instead of run as a script.
        from sqlalchemy import func, literal

        return func.AsText(literal(self, Geometry))

    @property
    def wkb(self):
        """SQL expression ``AsBinary(<this value>)``.

        Only the expression is built here; the caller has to execute it
        to obtain the result.
        """
        # See ``wkt`` for why the import is local.
        from sqlalchemy import func, literal

        return func.AsBinary(literal(self, Geometry))

    def __str__(self):
        """Return the geometry description stored in ``desc``."""
        return self.desc

    def __repr__(self):
        """Return ``<ClassName at 0xADDR; 'desc'>`` for debugging."""
        return "<%s at 0x%x; %r>" % (self.__class__.__name__, id(self), self.desc)
 
class PersistentGisElement(GisElement):
    """A geometry value in the form in which it was loaded from the database."""

    def __init__(self, desc):
        # Keep the value as given; GisElement reads it back through ``desc``.
        self.desc = desc
 
class TextualGisElement(GisElement, expression.Function):
    """A geometry value written in application code as a WKT string.

    Because this is also an ``expression.Function``, placing the object in
    a SQL expression renders it as ``GeomFromText(desc, srid)``.

    """

    def __init__(self, desc, srid=-1):
        # Only string descriptions are accepted.
        assert isinstance(desc, basestring)
        self.desc = desc
        # Set up the SQL function side explicitly: GeomFromText(desc, srid).
        function_name = "GeomFromText"
        expression.Function.__init__(self, function_name, desc, srid)
 
 
# SQL datatypes.
 
class Geometry(TypeEngine):
    """Base PostGIS Geometry column type.

    Bind values are reduced to their ``desc`` attribute; result values are
    wrapped in a PersistentGisElement.  ``None`` passes through unchanged
    in both directions.

    """

    name = 'GEOMETRY'

    def __init__(self, dimension=None, srid=-1):
        # Both values are read back by GISDDL when it adds the column.
        self.srid = srid
        self.dimension = dimension

    def bind_processor(self, dialect):
        """Return a callable mapping a bound element to its ``desc``."""
        def to_bind(element):
            return None if element is None else element.desc
        return to_bind

    def result_processor(self, dialect, coltype):
        """Return a callable wrapping a raw result in PersistentGisElement."""
        def from_result(raw):
            if raw is None:
                return raw
            return PersistentGisElement(raw)
        return from_result
 
# other datatypes can be added as needed, which 
# currently only affect DDL statements.
 
class Point(Geometry):
    """Geometry column type whose ``name`` is ``'POINT'``."""
    name = 'POINT'
 
class Curve(Geometry):
    """Geometry column type whose ``name`` is ``'CURVE'``."""
    name = 'CURVE'
 
class LineString(Curve):
    """Curve column type whose ``name`` is ``'LINESTRING'``."""
    name = 'LINESTRING'
 
# ... etc.
 
 
# DDL integration
 
class GISDDL(object):
    """A DDL extension which integrates SQLAlchemy table create/drop
    methods with PostGIS' AddGeometryColumn/DropGeometryColumn functions.

    Before a CREATE or DROP the Geometry-typed columns are taken out of the
    table's column collection, and the full collection is put back
    afterwards.  Geometry columns are added after the CREATE with
    ``AddGeometryColumn`` and removed before the DROP with
    ``DropGeometryColumn``.

    Usage::

        sometable = Table('sometable', metadata, ...)

        GISDDL(sometable)

        sometable.create()

    """

    def __init__(self, table):
        """Register this object as a listener for all four DDL events of *table*."""
        # Initialise state first so the instance is complete before it
        # becomes reachable through the table's listener lists.
        self._stack = []
        for event in ('before-create', 'after-create', 'before-drop', 'after-drop'):
            table.ddl_listeners[event].append(self)

    @staticmethod
    def _run_gis_function(bind, function_name, *args):
        """Execute ``SELECT <function_name>(*args)`` on *bind* with autocommit."""
        # Imported locally: ``func`` and ``select`` are not among the
        # module-level imports, so relying on globals raises NameError
        # whenever this module is imported instead of run as a script.
        from sqlalchemy import func, select

        gis_call = getattr(func, function_name)(*args)
        bind.execute(select([gis_call], autocommit=True))

    def __call__(self, event, table, bind):
        """Handle one DDL event for *table*.

        :param event: one of ``'before-create'``, ``'after-create'``,
            ``'before-drop'``, ``'after-drop'``; other values are ignored.
        :param table: the table being created or dropped.
        :param bind: object whose ``execute()`` runs the PostGIS calls.
        """
        if event in ('before-create', 'before-drop'):
            # Split the columns in table order.  A list (rather than a set)
            # keeps the DropGeometryColumn calls in a deterministic order.
            gis_cols = []
            regular_cols = []
            for c in table.c:
                if isinstance(c.type, Geometry):
                    gis_cols.append(c)
                else:
                    regular_cols.append(c)

            # Remember the full collection; the matching 'after-*' event
            # pops it again.
            self._stack.append(table.c)
            table._columns = expression.ColumnCollection(*regular_cols)

            if event == 'before-drop':
                for c in gis_cols:
                    self._run_gis_function(
                        bind, 'DropGeometryColumn', 'public', table.name, c.name)

        elif event == 'after-create':
            table._columns = self._stack.pop()

            for c in table.c:
                if isinstance(c.type, Geometry):
                    self._run_gis_function(
                        bind, 'AddGeometryColumn', table.name, c.name,
                        c.type.srid, c.type.name, c.type.dimension)

        elif event == 'after-drop':
            table._columns = self._stack.pop()
 
# ORM integration
 
def _to_postgis(value):
    """Interpret a value as a GIS-compatible construct."""
 
    if hasattr(value, '__clause_element__'):
        return value.__clause_element__()
    elif isinstance(value, (expression.ClauseElement, GisElement)):
        return value
    elif isinstance(value, basestring):
        return TextualGisElement(value)
    elif value is None:
        return None
    else:
        raise Exception("Invalid type")
 
 
class GisAttribute(AttributeExtension):
    """Attribute extension whose 'set' hook converts the incoming value
    to a GIS expression.

    """

    def set(self, state, value, oldvalue, initiator):
        # The returned value is the converted one, not the value passed in.
        converted = _to_postgis(value)
        return converted
 
class GisComparator(ColumnProperty.ColumnComparator):
    """Column comparator for mapped class attributes that replaces
    ``==`` and adds GIS-specific operators.

    """

    def __eq__(self, other):
        """Render ``==`` as the ``~=`` operator against the converted operand."""
        same_as = self.__clause_element__().op('~=')
        return same_as(_to_postgis(other))

    def intersects(self, other):
        """Custom operator: render ``&&`` against the converted operand."""
        overlaps = self.__clause_element__().op('&&')
        return overlaps(_to_postgis(other))
 
    # any number of GIS operators can be overridden/added here
    # using the techniques above.
 
 
def GISColumn(*args, **kw):
    """Define a declarative column property with GIS behavior.

    This just produces orm.column_property() with the appropriate
    extension and comparator_factory arguments.  The given arguments
    are passed through to Column.  The declarative module extracts
    the Column for inclusion in the mapped table.

    :param args: positional arguments forwarded to ``Column``.
    :param kw: keyword arguments forwarded to ``Column``.
    :returns: the result of ``column_property()`` using ``GisAttribute``
        as extension and ``GisComparator`` as comparator factory.
    """
    # Imported locally: ``Column`` and ``column_property`` are not among the
    # module-level imports, so relying on globals raises NameError whenever
    # this module is imported instead of run as a script.
    from sqlalchemy import Column
    from sqlalchemy.orm import column_property

    return column_property(
        Column(*args, **kw),
        extension=GisAttribute(),
        comparator_factory=GisComparator,
    )
 
# illustrate usage
# illustrate usage; needs a reachable PostgreSQL database at the URL below
if __name__ == '__main__':
    from sqlalchemy import (create_engine, MetaData, Column, Integer, String,
        func, literal, select)
    from sqlalchemy.orm import sessionmaker, column_property
    from sqlalchemy.ext.declarative import declarative_base

    engine = create_engine('postgresql://scott:tiger@localhost/gistest', echo=True)
    metadata = MetaData(engine)
    Base = declarative_base(metadata=metadata)

    class Road(Base):
        __tablename__ = 'roads'

        road_id = Column(Integer, primary_key=True)
        road_name = Column(String)
        road_geom = GISColumn(Geometry(2))

    # enable the DDL extension, which allows CREATE/DROP operations
    # to work correctly.  This is not needed if working with externally
    # defined tables.
    GISDDL(Road.__table__)

    # start from a clean slate
    metadata.drop_all()
    metadata.create_all()

    session = sessionmaker(bind=engine)()

    # Add objects.  We can use strings...
    session.add_all([
        Road(road_name='Jeff Rd', road_geom='LINESTRING(191232 243118,191108 243242)'),
        Road(road_name='Geordie Rd', road_geom='LINESTRING(189141 244158,189265 244817)'),
        Road(road_name='Paul St', road_geom='LINESTRING(192783 228138,192612 229814)'),
        Road(road_name='Graeme Ave', road_geom='LINESTRING(189412 252431,189631 259122)'),
        Road(road_name='Phil Tce', road_geom='LINESTRING(190131 224148,190871 228134)'),
    ])

    # or use an explicit TextualGisElement (similar to saying func.GeomFromText())
    r = Road(road_name='Dave Cres', road_geom=TextualGisElement('LINESTRING(198231 263418,198213 268322)', -1))
    session.add(r)

    # pre flush, the TextualGisElement represents the string we sent.
    assert str(r.road_geom) == 'LINESTRING(198231 263418,198213 268322)'
    assert session.scalar(r.road_geom.wkt) == 'LINESTRING(198231 263418,198213 268322)'

    session.commit()

    # after flush and/or commit, all the TextualGisElements become PersistentGisElements.
    assert str(r.road_geom) == "01020000000200000000000000B832084100000000E813104100000000283208410000000088601041"

    r1 = session.query(Road).filter(Road.road_name=='Graeme Ave').one()

    # illustrate the overridden __eq__() operator.

    # strings come in as TextualGisElements
    r2 = session.query(Road).filter(Road.road_geom == 'LINESTRING(189412 252431,189631 259122)').one()

    # PersistentGisElements work directly
    r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one()

    assert r1 is r2 is r3

    # illustrate the "intersects" operator
    # (single-argument print(...) behaves the same on Python 2 and also
    # parses on Python 3)
    print(session.query(Road).filter(Road.road_geom.intersects(r1.road_geom)).all())

    # illustrate usage of the "wkt" accessor. this requires a DB
    # execution to call the AsText() function so we keep this explicit.
    assert session.scalar(r1.road_geom.wkt) == 'LINESTRING(189412 252431,189631 259122)'

    session.rollback()

    # clean up the demo tables
    metadata.drop_all()