Create a training pipeline

Creates a training pipeline using the create_training_pipeline method.

Code sample

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

  import 
  
 com.google.cloud.aiplatform.v1. DeployedModelRef 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. EnvVar 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. FilterSplit 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. FractionSplit 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. InputDataConfig 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. LocationName 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. Model 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. Model 
. ExportFormat 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. ModelContainerSpec 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. PipelineServiceClient 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. PipelineServiceSettings 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. Port 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. PredefinedSplit 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. PredictSchemata 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. TimestampSplit 
 
 ; 
 import 
  
 com.google.cloud.aiplatform.v1. TrainingPipeline 
 
 ; 
 import 
  
 com.google.protobuf. Value 
 
 ; 
 import 
  
 com.google.protobuf.util. JsonFormat 
 
 ; 
 import 
  
 com.google.rpc. Status 
 
 ; 
 import 
  
 java.io.IOException 
 ; 
 /**
  * Sample that creates a Vertex AI training pipeline and prints the fields of the
  * {@code TrainingPipeline} returned by the service.
  */
 public class CreateTrainingPipelineSample {

   /**
    * Entry point: fills in placeholder values and delegates to
    * {@link #createTrainingPipelineSample(String, String, String, String, String)}.
    *
    * @param args command-line arguments (unused)
    * @throws IOException if the client cannot be created or the task inputs cannot be parsed
    */
   public static void main(String[] args) throws IOException {
     // TODO(developer): Replace these variables before running the sample.
     String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
     String project = "YOUR_PROJECT_ID";
     String datasetId = "YOUR_DATASET_ID";
     String trainingTaskDefinition = "YOUR_TRAINING_TASK_DEFINITION";
     String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
     createTrainingPipelineSample(
         project, trainingPipelineDisplayName, datasetId, trainingTaskDefinition, modelDisplayName);
   }

   /**
    * Creates a training pipeline in the {@code us-central1} location of the given project and
    * prints the response to standard output.
    *
    * @param project the project ID used to build the parent location name
    * @param trainingPipelineDisplayName display name set on the new training pipeline
    * @param datasetId ID of the dataset set on the pipeline's input data config
    * @param trainingTaskDefinition value set as the pipeline's training task definition
    * @param modelDisplayName display name set on the model to upload
    * @throws IOException if the client cannot be created or the task inputs JSON cannot be parsed
    */
   static void createTrainingPipelineSample(
       String project,
       String trainingPipelineDisplayName,
       String datasetId,
       String trainingTaskDefinition,
       String modelDisplayName)
       throws IOException {
     // The region is declared once; the regional endpoint is derived from it so that the
     // endpoint and the parent location cannot get out of sync.
     String location = "us-central1";
     PipelineServiceSettings pipelineServiceSettings =
         PipelineServiceSettings.newBuilder()
             .setEndpoint(location + "-aiplatform.googleapis.com:443")
             .build();

     // Initialize client that will be used to send requests. This client only needs to be created
     // once, and can be reused for multiple requests. After completing all of your requests, call
     // the "close" method on the client to safely clean up any remaining background resources.
     try (PipelineServiceClient pipelineServiceClient =
         PipelineServiceClient.create(pipelineServiceSettings)) {
       LocationName locationName = LocationName.of(project, location);

       // Training task inputs are supplied as JSON and merged into a protobuf Value.
       // NOTE(review): these keys presumably have to match the schema referenced by
       // trainingTaskDefinition - confirm for your task type.
       String jsonString =
           "{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
               + " \"disableEarlyStopping\": false}";
       Value.Builder trainingTaskInputs = Value.newBuilder();
       JsonFormat.parser().merge(jsonString, trainingTaskInputs);

       InputDataConfig trainingInputDataConfig =
           InputDataConfig.newBuilder().setDatasetId(datasetId).build();
       Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
       TrainingPipeline trainingPipeline =
           TrainingPipeline.newBuilder()
               .setDisplayName(trainingPipelineDisplayName)
               .setTrainingTaskDefinition(trainingTaskDefinition)
               .setTrainingTaskInputs(trainingTaskInputs)
               .setInputDataConfig(trainingInputDataConfig)
               .setModelToUpload(model)
               .build();

       TrainingPipeline trainingPipelineResponse =
           pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

       // Top-level pipeline fields.
       System.out.println("Create Training Pipeline Response");
       System.out.format("Name: %s\n", trainingPipelineResponse.getName());
       System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());
       System.out.format(
           "Training Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
       System.out.format(
           "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
       System.out.format(
           "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
       System.out.format("State: %s\n", trainingPipelineResponse.getState());
       System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
       System.out.format("Start Time: %s\n", trainingPipelineResponse.getStartTime());
       System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
       System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
       System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

       // Input data configuration and its split settings.
       InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
       System.out.println("Input Data Config");
       System.out.format("Dataset Id: %s\n", inputDataConfig.getDatasetId());
       System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

       FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
       System.out.println("Fraction Split");
       System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
       System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
       System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

       FilterSplit filterSplit = inputDataConfig.getFilterSplit();
       System.out.println("Filter Split");
       System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
       System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
       System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

       PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
       System.out.println("Predefined Split");
       System.out.format("Key: %s\n", predefinedSplit.getKey());

       TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
       System.out.println("Timestamp Split");
       System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
       System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
       System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
       System.out.format("Key: %s\n", timestampSplit.getKey());

       // Model that the pipeline will upload.
       Model modelResponse = trainingPipelineResponse.getModelToUpload();
       System.out.println("Model To Upload");
       System.out.format("Name: %s\n", modelResponse.getName());
       System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
       System.out.format("Description: %s\n", modelResponse.getDescription());
       System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
       System.out.format("Metadata: %s\n", modelResponse.getMetadata());
       System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
       System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());
       System.out.format(
           "Supported Deployment Resources Types: %s\n",
           modelResponse.getSupportedDeploymentResourcesTypesList());
       System.out.format(
           "Supported Input Storage Formats: %s\n",
           modelResponse.getSupportedInputStorageFormatsList());
       System.out.format(
           "Supported Output Storage Formats: %s\n",
           modelResponse.getSupportedOutputStorageFormatsList());
       System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
       System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
       System.out.format("Labels: %s\n", modelResponse.getLabelsMap());

       PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
       System.out.println("Predict Schemata");
       System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
       System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
       System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

       for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
         System.out.println("Supported Export Format");
         System.out.format("Id: %s\n", exportFormat.getId());
       }

       ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
       System.out.println("Container Spec");
       System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
       System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
       System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
       System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
       System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

       for (EnvVar envVar : modelContainerSpec.getEnvList()) {
         System.out.println("Env");
         System.out.format("Name: %s\n", envVar.getName());
         System.out.format("Value: %s\n", envVar.getValue());
       }

       for (Port port : modelContainerSpec.getPortsList()) {
         System.out.println("Port");
         System.out.format("Container Port: %s\n", port.getContainerPort());
       }

       for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
         System.out.println("Deployed Model");
         System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
         System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
       }

       // Error status attached to the pipeline response.
       Status status = trainingPipelineResponse.getError();
       System.out.println("Error");
       System.out.format("Code: %s\n", status.getCode());
       System.out.format("Message: %s\n", status.getMessage());
     }
   }
 }
 

Python

Before trying this sample, follow the Python setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Python API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

  from 
  
 google.cloud 
  
 import 
 aiplatform 
 from 
  
 google.protobuf 
  
 import 
 json_format 
 from 
  
 google.protobuf.struct_pb2 
  
 import 
 Value 
 def create_training_pipeline_sample(
     project: str,
     display_name: str,
     training_task_definition: str,
     dataset_id: str,
     model_display_name: str,
     location: str = "us-central1",
     api_endpoint: str = "us-central1-aiplatform.googleapis.com",
 ):
     """Create a training pipeline and print the service response.

     Args:
         project: Project ID used to build the parent resource path.
         display_name: Display name set on the training pipeline.
         training_task_definition: Value set as the pipeline's training task definition.
         dataset_id: Dataset ID placed in the pipeline's input data config.
         model_display_name: Display name set on the model to upload.
         location: Location used to build the parent resource path.
         api_endpoint: API endpoint passed to the client through its client options.
     """
     # The AI Platform services require regional API endpoints.
     # The client only needs to be created once and can be reused for multiple requests.
     pipeline_client = aiplatform.gapic.PipelineServiceClient(
         client_options={"api_endpoint": api_endpoint}
     )

     # Task inputs are written as a plain dict and converted to a protobuf Value.
     task_inputs = json_format.ParseDict(
         {
             "multiLabel": True,
             "modelType": "CLOUD",
             "budgetMilliNodeHours": 8000,
             "disableEarlyStopping": False,
         },
         Value(),
     )

     pipeline_spec = {
         "display_name": display_name,
         "training_task_definition": training_task_definition,
         "training_task_inputs": task_inputs,
         "input_data_config": {"dataset_id": dataset_id},
         "model_to_upload": {"display_name": model_display_name},
     }

     created = pipeline_client.create_training_pipeline(
         parent=f"projects/{project}/locations/{location}",
         training_pipeline=pipeline_spec,
     )
     print("response:", created)
 

What's next

To search and filter code samples for other Google Cloud products, see the Google Cloud sample browser.

Design a Mobile Site
View Site in Mobile | Classic
Share by: