Ruby on Rails | Screencasts | Download | Documentation | Weblog | Community | Source

root/tags/rel_1-0-0/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb

Revision 3236, 17.0 kB (checked in by bitsweat, 4 years ago)

r3375@asus: jeremy | 2005-12-07 20:36:13 -0800
Apply [3235] to stable. PostgreSQL: more robust sequence name discovery. Closes #3087.

Line 
1 require 'active_record/connection_adapters/abstract_adapter'
2
3 module ActiveRecord
4   class Base
5     # Establishes a connection to the database that's used by all Active Record objects
6     def self.postgresql_connection(config) # :nodoc:
7       require_library_or_gem 'postgres' unless self.class.const_defined?(:PGconn)
8
9       config = config.symbolize_keys
10       host     = config[:host]
11       port     = config[:port]     || 5432 unless host.nil?
12       username = config[:username].to_s
13       password = config[:password].to_s
14
15       min_messages = config[:min_messages]
16
17       if config.has_key?(:database)
18         database = config[:database]
19       else
20         raise ArgumentError, "No database specified. Missing argument: database."
21       end
22
23       pga = ConnectionAdapters::PostgreSQLAdapter.new(
24         PGconn.connect(host, port, "", "", database, username, password), logger, config
25       )
26
27       pga.schema_search_path = config[:schema_search_path] || config[:schema_order]
28
29       pga
30     end
31   end
32
33   module ConnectionAdapters
34     # The PostgreSQL adapter works both with the C-based (http://www.postgresql.jp/interfaces/ruby/) and the Ruby-based
35     # (available both as a gem and from http://rubyforge.org/frs/?group_id=234&release_id=1145) drivers.
36     #
37     # Options:
38     #
39     # * <tt>:host</tt> -- Defaults to localhost
40     # * <tt>:port</tt> -- Defaults to 5432 when <tt>:host</tt> is given
41     # * <tt>:username</tt> -- Defaults to nothing
42     # * <tt>:password</tt> -- Defaults to nothing
43     # * <tt>:database</tt> -- The name of the database. No default, must be provided.
44     # * <tt>:schema_search_path</tt> -- An optional schema search path for the connection given as a string of comma-separated schema names.  This is backward-compatible with the :schema_order option.
45     # * <tt>:encoding</tt> -- An optional client encoding that is used in a SET client_encoding TO <encoding> call on connection.
46     # * <tt>:min_messages</tt> -- An optional client min messages setting that is used in a SET client_min_messages TO <min_messages> call on connection.
47     class PostgreSQLAdapter < AbstractAdapter
48       def adapter_name
49         'PostgreSQL'
50       end
51
52       def initialize(connection, logger, config = {})
53         super(connection, logger)
54         @config = config
55         configure_connection
56       end
57
58       # Is this connection alive and ready for queries?
59       def active?
60         if @connection.respond_to?(:status)
61           @connection.status == PGconn::CONNECTION_OK
62         else
63           @connection.query 'SELECT 1'
64           true
65         end
66       rescue PGError
67         false
68       end
69
70       # Close then reopen the connection.
71       def reconnect!
72         # TODO: postgres-pr doesn't have PGconn#reset.
73         if @connection.respond_to?(:reset)
74           @connection.reset
75           configure_connection
76         end
77       end
78
79       def native_database_types
80         {
81           :primary_key => "serial primary key",
82           :string      => { :name => "character varying", :limit => 255 },
83           :text        => { :name => "text" },
84           :integer     => { :name => "integer" },
85           :float       => { :name => "float" },
86           :datetime    => { :name => "timestamp" },
87           :timestamp   => { :name => "timestamp" },
88           :time        => { :name => "time" },
89           :date        => { :name => "date" },
90           :binary      => { :name => "bytea" },
91           :boolean     => { :name => "boolean" }
92         }
93       end
94      
95       def supports_migrations?
96         true
97       end     
98      
99
100       # QUOTING ==================================================
101
102       def quote(value, column = nil)
103         if value.kind_of?(String) && column && column.type == :binary
104           "'#{escape_bytea(value)}'"
105         else
106           super
107         end
108       end
109
110       def quote_column_name(name)
111         %("#{name}")
112       end
113
114
115       # DATABASE STATEMENTS ======================================
116
117       def select_all(sql, name = nil) #:nodoc:
118         select(sql, name)
119       end
120
121       def select_one(sql, name = nil) #:nodoc:
122         result = select(sql, name)
123         result.first if result
124       end
125
126       def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil) #:nodoc:
127         execute(sql, name)
128         table = sql.split(" ", 4)[2]
129         id_value || last_insert_id(table, sequence_name || default_sequence_name(table, pk))
130       end
131
132       def query(sql, name = nil) #:nodoc:
133         log(sql, name) { @connection.query(sql) }
134       end
135
136       def execute(sql, name = nil) #:nodoc:
137         log(sql, name) { @connection.exec(sql) }
138       end
139
140       def update(sql, name = nil) #:nodoc:
141         execute(sql, name).cmdtuples
142       end
143
144       alias_method :delete, :update #:nodoc:
145
146
147       def begin_db_transaction #:nodoc:
148         execute "BEGIN"
149       end
150
151       def commit_db_transaction #:nodoc:
152         execute "COMMIT"
153       end
154      
155       def rollback_db_transaction #:nodoc:
156         execute "ROLLBACK"
157       end
158
159
160       # SCHEMA STATEMENTS ========================================
161
162       # Return the list of all tables in the schema search path.
163       def tables(name = nil) #:nodoc:
164         schemas = schema_search_path.split(/,/).map { |p| quote(p) }.join(',')
165         query(<<-SQL, name).map { |row| row[0] }
166           SELECT tablename
167             FROM pg_tables
168            WHERE schemaname IN (#{schemas})
169         SQL
170       end
171
172       def indexes(table_name, name = nil) #:nodoc:
173         result = query(<<-SQL, name)
174           SELECT i.relname, d.indisunique, a.attname
175             FROM pg_class t, pg_class i, pg_index d, pg_attribute a
176            WHERE i.relkind = 'i'
177              AND d.indexrelid = i.oid
178              AND d.indisprimary = 'f'
179              AND t.oid = d.indrelid
180              AND t.relname = '#{table_name}'
181              AND a.attrelid = t.oid
182              AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
183                 OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
184                 OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
185                 OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
186                 OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
187           ORDER BY i.relname
188         SQL
189
190         current_index = nil
191         indexes = []
192
193         result.each do |row|
194           if current_index != row[0]
195             indexes << IndexDefinition.new(table_name, row[0], row[1] == "t", [])
196             current_index = row[0]
197           end
198
199           indexes.last.columns << row[2]
200         end
201
202         indexes
203       end
204
205       def columns(table_name, name = nil) #:nodoc:
206         column_definitions(table_name).collect do |name, type, default, notnull|
207           Column.new(name, default_value(default), translate_field_type(type),
208             notnull == "f")
209         end
210       end
211
212       # Set the schema search path to a string of comma-separated schema names.
213       # Names beginning with $ are quoted (e.g. $user => '$user')
214       # See http://www.postgresql.org/docs/8.0/interactive/ddl-schemas.html
215       def schema_search_path=(schema_csv) #:nodoc:
216         if schema_csv
217           execute "SET search_path TO #{schema_csv}"
218           @schema_search_path = nil
219         end
220       end
221
222       def schema_search_path #:nodoc:
223         @schema_search_path ||= query('SHOW search_path')[0][0]
224       end
225
226       def default_sequence_name(table_name, pk = nil)
227         default_pk, default_seq = pk_and_sequence_for(table_name)
228         default_seq || "#{table_name}_#{pk || default_pk || 'id'}_seq"
229       end
230
231       # Resets sequence to the max value of the table's pk if present.
232       def reset_pk_sequence!(table, pk = nil, sequence = nil)
233         unless pk and sequence
234           default_pk, default_sequence = pk_and_sequence_for(table)
235           pk ||= default_pk
236           sequence ||= default_sequence
237         end
238         if pk
239           if sequence
240             select_value <<-end_sql, 'Reset sequence'
241               SELECT setval('#{sequence}', (SELECT COALESCE(MAX(#{pk})+(SELECT increment_by FROM #{sequence}), (SELECT min_value FROM #{sequence})) FROM #{table}), false)
242             end_sql
243           else
244             @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
245           end
246         end
247       end
248
249       # Find a table's primary key and sequence.
250       def pk_and_sequence_for(table)
251         # First try looking for a sequence with a dependency on the
252         # given table's primary key.
253         result = execute(<<-end_sql, 'PK and serial sequence')[0]
254           SELECT attr.attname, name.nspname, seq.relname
255           FROM pg_class      seq,
256                pg_attribute  attr,
257                pg_depend     dep,
258                pg_namespace  name,
259                pg_constraint cons
260           WHERE seq.oid           = dep.objid
261             AND seq.relnamespace  = name.oid
262             AND seq.relkind       = 'S'
263             AND attr.attrelid     = dep.refobjid
264             AND attr.attnum       = dep.refobjsubid
265             AND attr.attrelid     = cons.conrelid
266             AND attr.attnum       = cons.conkey[1]
267             AND cons.contype      = 'p'
268             AND dep.refobjid      = '#{table}'::regclass
269         end_sql
270
271         if result.nil? or result.empty?
272           # If that fails, try parsing the primary key's default value.
273           # Support the 7.x and 8.0 nextval('foo'::text) as well as
274           # the 8.1+ nextval('foo'::regclass).
275           # TODO: assumes sequence is in same schema as table.
276           result = execute(<<-end_sql, 'PK and custom sequence')[0]
277             SELECT attr.attname, name.nspname, split_part(def.adsrc, '\\\'', 2)
278             FROM pg_class       t
279             JOIN pg_namespace   name ON (t.relnamespace = name.oid)
280             JOIN pg_attribute   attr ON (t.oid = attrelid)
281             JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
282             JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
283             WHERE t.oid = '#{table}'::regclass
284               AND cons.contype = 'p'
285               AND def.adsrc ~* 'nextval'
286           end_sql
287         end
288         # check for existence of . in sequence name as in public.foo_sequence.  if it does not exist, join the current namespace
289         result.last['.'] ? [result.first, result.last] : [result.first, "#{result[1]}.#{result[2]}"]
290       rescue
291         nil
292       end
293
294       def rename_table(name, new_name)
295         execute "ALTER TABLE #{name} RENAME TO #{new_name}"
296       end
297            
298       def add_column(table_name, column_name, type, options = {})
299         native_type = native_database_types[type]
300         sql_commands = ["ALTER TABLE #{table_name} ADD #{column_name} #{type_to_sql(type, options[:limit])}"]
301         if options[:default]
302           sql_commands << "ALTER TABLE #{table_name} ALTER #{column_name} SET DEFAULT '#{options[:default]}'"
303         end
304         if options[:null] == false
305           sql_commands << "ALTER TABLE #{table_name} ALTER #{column_name} SET NOT NULL"
306         end
307         sql_commands.each { |cmd| execute(cmd) }
308       end
309
310       def change_column(table_name, column_name, type, options = {}) #:nodoc:
311         execute = "ALTER TABLE #{table_name} ALTER  #{column_name} TYPE #{type}"
312         change_column_default(table_name, column_name, options[:default]) unless options[:default].nil?
313       end     
314
315       def change_column_default(table_name, column_name, default) #:nodoc:
316         execute "ALTER TABLE #{table_name} ALTER COLUMN #{column_name} SET DEFAULT '#{default}'"
317       end
318      
319       def rename_column(table_name, column_name, new_column_name) #:nodoc:
320         execute "ALTER TABLE #{table_name} RENAME COLUMN #{column_name} TO #{new_column_name}"
321       end
322
323       def remove_index(table_name, options) #:nodoc:
324         if Hash === options
325           index_name = options[:name]
326         else
327           index_name = "#{table_name}_#{options}_index"
328         end
329
330         execute "DROP INDEX #{index_name}"
331       end     
332
333
334       private
335         BYTEA_COLUMN_TYPE_OID = 17
336
337         def configure_connection
338           if @config[:encoding]
339             execute("SET client_encoding TO '#{@config[:encoding]}'")
340           end
341           if @config[:min_messages]
342             execute("SET client_min_messages TO '#{@config[:min_messages]}'")
343           end
344         end
345
346         def last_insert_id(table, sequence_name)
347           Integer(select_value("SELECT currval('#{sequence_name}')"))
348         end
349
350         def select(sql, name = nil)
351           res = execute(sql, name)
352           results = res.result           
353           rows = []
354           if results.length > 0
355             fields = res.fields
356             results.each do |row|
357               hashed_row = {}
358               row.each_index do |cel_index|
359                 column = row[cel_index]
360                 if res.type(cel_index) == BYTEA_COLUMN_TYPE_OID
361                   column = unescape_bytea(column)
362                 end
363                 hashed_row[fields[cel_index]] = column
364               end
365               rows << hashed_row
366             end
367           end
368           return rows
369         end
370
371         def escape_bytea(s)
372           if PGconn.respond_to? :escape_bytea
373             self.class.send(:define_method, :escape_bytea) do |s|
374               PGconn.escape_bytea(s) if s
375             end
376           else
377             self.class.send(:define_method, :escape_bytea) do |s|
378               if s
379                 result = ''
380                 s.each_byte { |c| result << sprintf('\\\\%03o', c) }
381                 result
382               end
383             end
384           end
385           escape_bytea(s)
386         end
387
388         def unescape_bytea(s)
389           if PGconn.respond_to? :unescape_bytea
390             self.class.send(:define_method, :unescape_bytea) do |s|
391               PGconn.unescape_bytea(s) if s
392             end
393           else
394             self.class.send(:define_method, :unescape_bytea) do |s|
395               if s
396                 result = ''
397                 i, max = 0, s.size
398                 while i < max
399                   char = s[i]
400                   if char == ?\\
401                     if s[i+1] == ?\\
402                       char = ?\\
403                       i += 1
404                     else
405                       char = s[i+1..i+3].oct
406                       i += 3
407                     end
408                   end
409                   result << char
410                   i += 1
411                 end
412                 result
413               end
414             end
415           end
416           unescape_bytea(s)
417         end
418        
419         # Query a table's column names, default values, and types.
420         #
421         # The underlying query is roughly:
422         #  SELECT column.name, column.type, default.value
423         #    FROM column LEFT JOIN default
424         #      ON column.table_id = default.table_id
425         #     AND column.num = default.column_num
426         #   WHERE column.table_id = get_table_id('table_name')
427         #     AND column.num > 0
428         #     AND NOT column.is_dropped
429         #   ORDER BY column.num
430         #
431         # If the table name is not prefixed with a schema, the database will
432         # take the first match from the schema search path.
433         #
434         # Query implementation notes:
435         #  - format_type includes the column size constraint, e.g. varchar(50)
436         #  - ::regclass is a function that gives the id for a table name
437         def column_definitions(table_name)
438           query <<-end_sql
439             SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull
440               FROM pg_attribute a LEFT JOIN pg_attrdef d
441                 ON a.attrelid = d.adrelid AND a.attnum = d.adnum
442              WHERE a.attrelid = '#{table_name}'::regclass
443                AND a.attnum > 0 AND NOT a.attisdropped
444              ORDER BY a.attnum
445           end_sql
446         end
447
448         # Translate PostgreSQL-specific types into simplified SQL types.
449         # These are special cases; standard types are handled by
450         # ConnectionAdapters::Column#simplified_type.
451         def translate_field_type(field_type)
452           # Match the beginning of field_type since it may have a size constraint on the end.
453           case field_type
454             when /^timestamp/i    then 'datetime'
455             when /^real|^money/i  then 'float'
456             when /^interval/i     then 'string'
457             # geometric types (the line type is currently not implemented in postgresql)
458             when /^(?:point|lseg|box|"?path"?|polygon|circle)/i  then 'string'
459             when /^bytea/i        then 'binary'
460             else field_type       # Pass through standard types.
461           end
462         end
463
464         def default_value(value)
465           # Boolean types
466           return "t" if value =~ /true/i
467           return "f" if value =~ /false/i
468          
469           # Char/String type values
470           return $1 if value =~ /^'(.*)'::(bpchar|text|character varying)$/
471          
472           # Numeric values
473           return value if value =~ /^[0-9]+(\.[0-9]*)?/
474
475           # Date / Time magic values
476           return Time.now.to_s if value =~ /^now\(\)|^\('now'::text\)::(date|timestamp)/i
477
478           # Fixed dates / times
479           return $1 if value =~ /^'(.+)'::(date|timestamp)/
480          
481           # Anything else is blank, some user type, or some function
482           # and we can't know the value of that, so return nil.
483           return nil
484         end
485     end
486   end
487 end
Note: See TracBrowser for help on using the browser.