]> git.wincent.com - mkdtemp.git/commitdiff
Teach mkdtemp to take optional block parameter
authorWincent Colaiuta <win@wincent.com>
Wed, 28 Jul 2010 20:21:04 +0000 (22:21 +0200)
committerWincent Colaiuta <win@wincent.com>
Wed, 28 Jul 2010 20:21:04 +0000 (22:21 +0200)
Signed-off-by: Wincent Colaiuta <win@wincent.com>
ext/mkdtemp.c
spec/mkdtemp_spec.rb

index b17fe1e5d73a4cd94c1116428bb7b31b55cc0e7c..cd6f85640dc7dc8611737c6d3fc8207be108c28e 100644 (file)
 #include <unistd.h>
 #include "ruby_compat.h"
 
+// helper function needed by rb_iterate; see:
+//  http://blade.nagaokaut.ac.jp/cgi-bin/scat.rb/ruby/ruby-talk/144100
+VALUE call_chdir(VALUE dir)
+{
+    return rb_funcall(rb_cDir, rb_intern("chdir"), 1, dir);
+}
+
+// helper function needed by rb_iterate
+VALUE yield_block(VALUE ignored, VALUE block)
+{
+    return rb_funcall(block, rb_intern("call"), 0);
+}
+
 // call-seq:
 //     Dir.mkdtemp([string])   -> String or nil
 //
 // and overwriting the template in-place; if no template is supplied then
 // "/tmp/temp.XXXXXX" is used as a default.
 //
+// If supplied a block, performs a Dir.chdir into the created directory and
+// yields to the block:
+//
+//      # this:            # is a shorthand for:
+//      Dir.mkdtemp do     #   dir = Dir.mkdtemp
+//        puts Dir.pwd     #   Dir.chdir dir do
+//      end                #     puts Dir.pwd
+//                         #   end
+//
 // Note that the exact implementation of mkdtemp() may vary depending on the
 // target system. For example, on Mac OS X at the time of writing, the man page
 // states that the template may contain "some number" of "Xs" on the end of the
 // suffix "must be XXXXXX".
 static VALUE dir_mkdtemp_m(int argc, VALUE *argv, VALUE self)
 {
-    VALUE template;
+    VALUE template, block;
     char *c_template;
     char *path;
 
     // process arguments
-    if (rb_scan_args(argc, argv, "01", &template) == 0) // check for 0 mandatory arguments, 1 optional argument
-        template = Qnil;                                // default to nil if no argument passed
+    if (rb_scan_args(argc, argv, "01&", &template, &block) == 0)    // 0 mandatory, 1 optional, 1 block
+        template = Qnil;                                            // default to nil if no argument passed
     if (NIL_P(template))
-        template = rb_str_new2("/tmp/temp.XXXXXX");     // fallback to this template if passed nil
-    SafeStringValue(template);                          // raises if template is tainted and SAFE level > 0
-    template = StringValue(template);                   // duck typing support
+        template = rb_str_new2("/tmp/temp.XXXXXX");                 // fallback to this template if passed nil
+    SafeStringValue(template);                                      // raises if template is tainted and SAFE level > 0
+    template = StringValue(template);                               // duck typing support
 
     // create temporary storage
     c_template = malloc(RSTRING_LEN(template) + 1);
@@ -68,6 +90,10 @@ static VALUE dir_mkdtemp_m(int argc, VALUE *argv, VALUE self)
     free(c_template);
     if (path == NULL)
         rb_raise(rb_eSystemCallError, "mkdtemp failed (error #%d: %s)", errno, strerror(errno));
+
+    // yield to block if given, inside Dir.chdir
+    if (rb_block_given_p() == Qtrue)
+        rb_iterate(call_chdir, template, yield_block, block);
     return template;
 }
 
index 10a2f24f732989265da181d38dad6a8fbcbf4890..a1fdd8b728b0abe336459edc1fd767e252ce2368 100644 (file)
@@ -84,4 +84,33 @@ describe 'Dir.mkdtemp' do
     path = Dir.mkdtemp '/tmp/test.XXXXXX'
     path.should match(%r{\A/tmp/test\..{6}\z})
   end
+
+  context 'with optional block parameter' do
+    it 'performs a Dir.chdir into the created directory' do
+      cd_path = nil
+      path = Dir.mkdtemp do
+        cd_path = Dir.pwd
+      end
+      Pathname.new(path).realpath.should == Pathname.new(cd_path).realpath
+    end
+
+    it 'lets exceptions bubble up to calling context' do
+      expect do
+        Dir.mkdtemp do
+          raise 'bubble'
+        end
+      end.to raise_error(/bubble/)
+    end
+
+    it 'preserves current directory even when an exception occurs' do
+      path = Dir.pwd
+      begin
+        Dir.mkdtemp do
+          raise 'error'
+        end
+      rescue
+      end
+      Dir.pwd.should == path
+    end
+  end
 end