Selfie segmentation with ML Kit on Android

ML Kit provides an optimized SDK for selfie segmentation.

The Selfie Segmenter assets are statically linked to your app at build time. This increases your app's download size by about 4.5MB, and API latency can vary from 25ms to 65ms depending on the input image size, as measured on a Pixel 4.

Try it out

  • Play around with the sample app to see an example usage of this API.

Before you begin

  1. In your project-level build.gradle file, make sure to include Google's Maven repository in both your buildscript and allprojects sections.
  2. Add the dependencies for the ML Kit Android libraries to your module's app-level gradle file, which is usually app/build.gradle:
dependencies {
    implementation 'com.google.mlkit:segmentation-selfie:16.0.0-beta6'
}

1. Create an instance of Segmenter

Segmenter options

To do segmentation on an image, first create an instance of Segmenter by specifying the following options.

Detector Mode

The Segmenter operates in two modes. Be sure you choose the one that matches your use case.

STREAM_MODE (default)

This mode is designed for streaming frames from video or camera. In this mode, the segmenter will leverage results from previous frames to return smoother segmentation results.

SINGLE_IMAGE_MODE

This mode is designed for single images that are not related. In this mode, the segmenter will process each image independently, with no smoothing over frames.
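For example, here is a minimal sketch of configuring the segmenter for unrelated still images; it mirrors the builder shown below, only swapping in SINGLE_IMAGE_MODE (the variable name stillImageOptions is illustrative):

Kotlin

// A minimal sketch: select SINGLE_IMAGE_MODE for unrelated still images.
val stillImageOptions = SelfieSegmenterOptions.Builder()
        .setDetectorMode(SelfieSegmenterOptions.SINGLE_IMAGE_MODE)
        .build()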

Enable raw size mask

Asks the segmenter to return the raw size mask, which matches the model output size.

The raw mask size (e.g. 256x256) is usually smaller than the input image size. Call SegmentationMask#getWidth() and SegmentationMask#getHeight() to get the mask size when enabling this option.

Without specifying this option, the segmenter rescales the raw mask to match the input image size. Consider using this option if you want to apply customized rescaling logic or if rescaling is not needed for your use case.
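For example, if you draw the mask over a preview View, you can request the raw size mask and do a single scaling pass straight to the view size. A minimal sketch, assuming you have already converted the mask buffer into maskBitmap (a hypothetical Bitmap; see step 4 for reading the buffer) and that viewWidth and viewHeight are your target dimensions:

Kotlin

// A minimal sketch: one rescale from raw mask size directly to the view size.
// maskBitmap, viewWidth, and viewHeight are assumptions, not part of the API.
val overlay = Bitmap.createScaledBitmap(maskBitmap, viewWidth, viewHeight, true)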

Specify the segmenter options:

Kotlin

val options = SelfieSegmenterOptions.Builder()
        .setDetectorMode(SelfieSegmenterOptions.STREAM_MODE)
        .enableRawSizeMask()
        .build()

Java

SelfieSegmenterOptions options =
        new SelfieSegmenterOptions.Builder()
                .setDetectorMode(SelfieSegmenterOptions.STREAM_MODE)
                .enableRawSizeMask()
                .build();

Create an instance of Segmenter, passing the options you specified:

Kotlin

val segmenter = Segmentation.getClient(options)

Java

Segmenter segmenter = Segmentation.getClient(options);
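The segmenter holds native resources. Like other ML Kit vision clients, Segmenter exposes a close() method, so release it when it is no longer needed; a minimal sketch, assuming the segmenter lives in an Activity:

Kotlin

// A sketch: release the segmenter's resources when the screen is destroyed.
override fun onDestroy() {
    super.onDestroy()
    segmenter.close()
}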

2. Prepare the input image

To perform segmentation on an image, create an InputImage object from either a Bitmap, media.Image, ByteBuffer, byte array, or a file on the device.

You can create an InputImage object from different sources; each is explained below.

Using a media.Image

To create an InputImage object from a media.Image object, such as when you capture an image from a device's camera, pass the media.Image object and the image's rotation to InputImage.fromMediaImage().

If you use the CameraX library, the OnImageCapturedListener and ImageAnalysis.Analyzer classes calculate the rotation value for you.

Kotlin

private class YourImageAnalyzer : ImageAnalysis.Analyzer {

    override fun analyze(imageProxy: ImageProxy) {
        val mediaImage = imageProxy.image
        if (mediaImage != null) {
            val image = InputImage.fromMediaImage(mediaImage, imageProxy.imageInfo.rotationDegrees)
            // Pass image to an ML Kit Vision API
            // ...
        }
    }
}

Java

private class YourAnalyzer implements ImageAnalysis.Analyzer {

    @Override
    public void analyze(ImageProxy imageProxy) {
        Image mediaImage = imageProxy.getImage();
        if (mediaImage != null) {
            InputImage image =
                    InputImage.fromMediaImage(mediaImage, imageProxy.getImageInfo().getRotationDegrees());
            // Pass image to an ML Kit Vision API
            // ...
        }
    }
}

If you don't use a camera library that gives you the image's rotation degree, you can calculate it from the device's rotation degree and the orientation of the camera sensor in the device:

Kotlin

private val ORIENTATIONS = SparseIntArray()

init {
    ORIENTATIONS.append(Surface.ROTATION_0, 0)
    ORIENTATIONS.append(Surface.ROTATION_90, 90)
    ORIENTATIONS.append(Surface.ROTATION_180, 180)
    ORIENTATIONS.append(Surface.ROTATION_270, 270)
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
@Throws(CameraAccessException::class)
private fun getRotationCompensation(cameraId: String, activity: Activity, isFrontFacing: Boolean): Int {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    val deviceRotation = activity.windowManager.defaultDisplay.rotation
    var rotationCompensation = ORIENTATIONS.get(deviceRotation)

    // Get the device's sensor orientation.
    val cameraManager = activity.getSystemService(CAMERA_SERVICE) as CameraManager
    val sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION)!!

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360
    }
    return rotationCompensation
}

Java

private static final SparseIntArray ORIENTATIONS = new SparseIntArray();

static {
    ORIENTATIONS.append(Surface.ROTATION_0, 0);
    ORIENTATIONS.append(Surface.ROTATION_90, 90);
    ORIENTATIONS.append(Surface.ROTATION_180, 180);
    ORIENTATIONS.append(Surface.ROTATION_270, 270);
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
private int getRotationCompensation(String cameraId, Activity activity, boolean isFrontFacing)
        throws CameraAccessException {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    int deviceRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
    int rotationCompensation = ORIENTATIONS.get(deviceRotation);

    // Get the device's sensor orientation.
    CameraManager cameraManager = (CameraManager) activity.getSystemService(CAMERA_SERVICE);
    int sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION);

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360;
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360;
    }
    return rotationCompensation;
}

Then, pass the media.Image object and the rotation degree value to InputImage.fromMediaImage():

Kotlin

val image = InputImage.fromMediaImage(mediaImage, rotation)

Java

InputImage image = InputImage.fromMediaImage(mediaImage, rotation);

Using a file URI

To create an InputImage object from a file URI, pass the app context and file URI to InputImage.fromFilePath(). This is useful when you use an ACTION_GET_CONTENT intent to prompt the user to select an image from their gallery app.

Kotlin

val image: InputImage
try {
    image = InputImage.fromFilePath(context, uri)
} catch (e: IOException) {
    e.printStackTrace()
}

Java

InputImage image;
try {
    image = InputImage.fromFilePath(context, uri);
} catch (IOException e) {
    e.printStackTrace();
}
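For context, here is a minimal sketch of the ACTION_GET_CONTENT flow mentioned above; the request code, the Activity wiring, and the use of the deprecated startActivityForResult() are illustrative assumptions, not part of ML Kit:

Kotlin

// A sketch: let the user pick an image, then build an InputImage from the URI.
private val REQUEST_PICK_IMAGE = 1 // arbitrary request code

fun pickImage() {
    val intent = Intent(Intent.ACTION_GET_CONTENT).apply { type = "image/*" }
    startActivityForResult(intent, REQUEST_PICK_IMAGE)
}

override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
    super.onActivityResult(requestCode, resultCode, data)
    if (requestCode == REQUEST_PICK_IMAGE && resultCode == Activity.RESULT_OK) {
        data?.data?.let { uri ->
            val image = InputImage.fromFilePath(this, uri) // may throw IOException
            // Pass image to the segmenter...
        }
    }
}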

Using a ByteBuffer or ByteArray

To create an InputImage object from a ByteBuffer or a ByteArray, first calculate the image rotation degree as previously described for media.Image input. Then, create the InputImage object with the buffer or array, together with the image's height, width, color encoding format, and rotation degree:

Kotlin

val image = InputImage.fromByteBuffer(
        byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)

// Or:
val image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)

Java

InputImage image = InputImage.fromByteBuffer(
        byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);

// Or:
InputImage image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);

Using a Bitmap

To create an InputImage object from a Bitmap object, make the following declaration:

Kotlin

val image = InputImage.fromBitmap(bitmap, 0)

Java

InputImage image = InputImage.fromBitmap(bitmap, rotationDegree);

The image is represented by a Bitmap object together with rotation degrees.

3. Process the image

Pass the prepared InputImage object to the Segmenter's process method.

Kotlin

val result = segmenter.process(image)
    .addOnSuccessListener { results ->
        // Task completed successfully
        // ...
    }
    .addOnFailureListener { e ->
        // Task failed with an exception
        // ...
    }

Java

Task<SegmentationMask> result =
        segmenter.process(image)
                .addOnSuccessListener(
                        new OnSuccessListener<SegmentationMask>() {
                            @Override
                            public void onSuccess(SegmentationMask mask) {
                                // Task completed successfully
                                // ...
                            }
                        })
                .addOnFailureListener(
                        new OnFailureListener() {
                            @Override
                            public void onFailure(@NonNull Exception e) {
                                // Task failed with an exception
                                // ...
                            }
                        });
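If your app uses Kotlin coroutines, the returned Task can also be awaited instead of handled with listeners; a sketch, assuming the kotlinx-coroutines-play-services artifact is on the classpath (it provides the Task.await() extension):

Kotlin

import kotlinx.coroutines.tasks.await

// A sketch: suspend until segmentation completes, propagating failures as exceptions.
suspend fun segment(image: InputImage): SegmentationMask =
    segmenter.process(image).await()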

4. Get the segmentation result

You can get the segmentation result as follows:

Kotlin

val mask = segmentationMask.buffer
val maskWidth = segmentationMask.width
val maskHeight = segmentationMask.height

for (y in 0 until maskHeight) {
    for (x in 0 until maskWidth) {
        // Gets the confidence of the (x,y) pixel in the mask being in the foreground.
        val foregroundConfidence = mask.float
    }
}

Java

ByteBuffer mask = segmentationMask.getBuffer();
int maskWidth = segmentationMask.getWidth();
int maskHeight = segmentationMask.getHeight();

for (int y = 0; y < maskHeight; y++) {
    for (int x = 0; x < maskWidth; x++) {
        // Gets the confidence of the (x,y) pixel in the mask being in the foreground.
        float foregroundConfidence = mask.getFloat();
    }
}
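As one illustration of consuming these confidence values, the sketch below turns the mask into a translucent overlay bitmap; the 0.5f threshold, the green tint, and the maskToBitmap name are arbitrary choices for the example, not part of the API:

Kotlin

// A sketch: convert the confidence mask into an ARGB_8888 overlay bitmap.
fun maskToBitmap(segmentationMask: SegmentationMask): Bitmap {
    val mask = segmentationMask.buffer
    val width = segmentationMask.width
    val height = segmentationMask.height
    val colors = IntArray(width * height)
    for (i in colors.indices) {
        val confidence = mask.float
        // Tint likely-foreground pixels green; leave the rest transparent.
        colors[i] = if (confidence > 0.5f) 0x8000FF00.toInt() else 0x00000000
    }
    mask.rewind() // let the buffer be read again later if needed
    return Bitmap.createBitmap(colors, width, height, Bitmap.Config.ARGB_8888)
}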

For a full example of how to use the segmentation results, please see the ML Kit quickstart sample.

Tips to improve performance

The quality of your results depends on the quality of the input image:

  • For ML Kit to get an accurate segmentation result, the image should be at least 256x256 pixels.
  • Poor image focus can also impact accuracy. If you don't get acceptable results, ask the user to recapture the image.

If you want to use segmentation in a real-time application, follow these guidelines to achieve the best frame rates:

  • Use STREAM_MODE.
  • Consider capturing images at a lower resolution. However, also keep in mind this API's image dimension requirements.
  • Consider enabling the raw size mask option and combining all rescaling logic together. For example, instead of letting the API rescale the mask to match your input image size and then rescaling it again to match the View size for display, request the raw size mask and combine those two steps into one.
  • If you use the Camera or camera2 API, throttle calls to the detector. If a new video frame becomes available while the detector is running, drop the frame. See the VisionProcessorBase class in the quickstart sample app for an example; a minimal version of this pattern is also sketched after this list.
  • If you use the CameraX API, be sure that the backpressure strategy is set to its default value, ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST. This guarantees that only one image will be delivered for analysis at a time. If more images are produced while the analyzer is busy, they will be dropped automatically and not queued for delivery. Once the image being analyzed is closed by calling ImageProxy.close(), the next latest image will be delivered.
  • If you use the output of the detector to overlay graphics on the input image, first get the result from ML Kit, then render the image and overlay in a single step. This renders to the display surface only once for each input frame. See the CameraSourcePreview and GraphicOverlay classes in the quickstart sample app for an example.
  • If you use the Camera2 API, capture images in ImageFormat.YUV_420_888 format. If you use the older Camera API, capture images in ImageFormat.NV21 format.
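The frame-dropping pattern from the Camera/camera2 bullet above can be as simple as a busy flag; a minimal sketch (isProcessing and onFrame are illustrative names, not taken from the sample app):

Kotlin

// A sketch of throttling: drop frames that arrive while the segmenter is busy.
@Volatile private var isProcessing = false

fun onFrame(image: InputImage) {
    if (isProcessing) return // drop this frame
    isProcessing = true
    segmenter.process(image)
        .addOnCompleteListener { isProcessing = false }
}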