Train a text entity extraction model

This page shows you how to train an AutoML entity extraction model from a text dataset using either the Google Cloud console or the Vertex AI API.

Before you begin

Before you can train a text entity extraction model, you must complete the following:

Train an AutoML model

Google Cloud console

  1. In the Google Cloud console, in the Vertex AI section, go to the Datasets page.

  2. Click the name of the dataset you want to use to train your model to open its details page.

  3. Select the annotation set you want to use for this model.

  4. Click Train new model.

  5. For the training method, select AutoML.

  6. Click Continue.

  7. Enter a name for the model.

  8. If you want to manually set how your training data is split, expand Advanced options and select a data split option. Learn more.

  9. Click Start Training.

    Model training can take many hours, depending on the size and complexity of your data and your training budget, if you specified one. You can close this tab and return to it later. You will receive an email when your model has completed training.

API

Select a tab for your language or environment:

REST

Create a TrainingPipeline object to train a model.

Before using any of the request data, make the following replacements:

  • LOCATION: The region where the model will be created, such as us-central1
  • PROJECT: Your project ID
  • MODEL_DISPLAY_NAME: Name for the model as it appears in the user interface
  • DATASET_ID: The ID for the dataset
  • PROJECT_NUMBER: Your project's automatically generated project number

HTTP method and URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Request JSON body:

{
  "displayName": "MODEL_DISPLAY_NAME",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  }
}
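
As a rough illustration, the URL and request body above could be assembled in Python using only the standard library. The function name build_training_pipeline_request and the sample argument values are hypothetical and not part of any Google client library. The sketch only builds the request; it does not send it or handle authentication.

```python
import json

# Schema URI copied from the request body shown above.
TRAINING_TASK_DEFINITION = (
    "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
    "automl_text_extraction_1.0.0.yaml"
)


def build_training_pipeline_request(location, project, model_display_name, dataset_id):
    """Return (url, body) for the trainingPipelines POST request (hypothetical helper)."""
    url = (
        f"https://{location}-aiplatform.googleapis.com/v1/"
        f"projects/{project}/locations/{location}/trainingPipelines"
    )
    body = {
        "displayName": model_display_name,
        "trainingTaskDefinition": TRAINING_TASK_DEFINITION,
        "modelToUpload": {"displayName": model_display_name},
        "inputDataConfig": {"datasetId": dataset_id},
    }
    return url, body


if __name__ == "__main__":
    # Placeholder values for illustration only.
    url, body = build_training_pipeline_request(
        "us-central1", "my-project", "my-model", "1234567890"
    )
    print("POST", url)
    print(json.dumps(body, indent=2))
```

Keeping the URL and body construction in one function makes it easy to check that the same LOCATION value is used in both the host name and the resource path.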

You should receive a JSON response similar to the following:

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/PIPELINE_ID",
  "displayName": "MODEL_DISPLAY_NAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}
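
A response like the one above can be inspected with a few lines of Python. The following is a minimal sketch, assuming the name field follows the projects/PROJECT_NUMBER/locations/LOCATION/trainingPipelines/PIPELINE_ID pattern shown in the sample response. The helper name summarize_training_pipeline is hypothetical, and the set of terminal states is an assumption based on the PipelineState enum, so verify it against the API reference before relying on it.

```python
import json

# Assumption: these PipelineState values mean the pipeline has stopped running.
TERMINAL_STATES = frozenset(
    {
        "PIPELINE_STATE_SUCCEEDED",
        "PIPELINE_STATE_FAILED",
        "PIPELINE_STATE_CANCELLED",
    }
)


def summarize_training_pipeline(response_text):
    """Extract the IDs and state from a TrainingPipeline JSON response (hypothetical helper)."""
    response = json.loads(response_text)
    parts = response["name"].split("/")
    expected = ("projects", "locations", "trainingPipelines")
    # The collection names sit at even positions, the IDs at odd positions.
    if len(parts) != 6 or tuple(parts[0::2]) != expected:
        raise ValueError(f"Unexpected resource name: {response['name']}")
    state = response.get("state", "PIPELINE_STATE_UNSPECIFIED")
    return {
        "project_number": parts[1],
        "location": parts[3],
        "pipeline_id": parts[5],
        "state": state,
        "finished": state in TERMINAL_STATES,
    }


if __name__ == "__main__":
    # Made-up IDs for illustration only.
    sample = json.dumps(
        {
            "name": "projects/123456/locations/us-central1/trainingPipelines/789",
            "state": "PIPELINE_STATE_PENDING",
        }
    )
    print(summarize_training_pipeline(sample))
```

The pipeline ID extracted this way is what you would use to look up the pipeline again later, because training can take many hours.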

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineTextEntityExtractionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineTextEntityExtractionSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineTextEntityExtractionSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_text_extraction_1.0.0.yaml";

      LocationName locationName = LocationName.of(project, location);

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.EMPTY_VALUE)
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      // Print the top-level fields of the created training pipeline.
      System.out.println("Create Training Pipeline Text Entity Extraction Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      // Print the input data configuration, including every split type.
      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfig.getDatasetId());
      System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      // Print the model that the pipeline uploads when training completes.
      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());
      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLabels: %s\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("\t\tPredict Schemata");
      System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\t\tSupported Export Format");
        System.out.format("\t\t\tId: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("\t\tContainer Spec");
      System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
      System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
      System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("\t\t\tEnv");
        System.out.format("\t\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("\t\t\tPort");
        System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\t\tDeployed Model");
        System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      // Print any error reported for the pipeline.
      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}
 

Node.js

Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTextEntityExtraction() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputObj = new definition.AutoMlTextExtractionInputs({});
  const trainingTaskInputs = trainingTaskInputObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline text entity extraction response :');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTextEntityExtraction();
 

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

from typing import Optional

from google.cloud import aiplatform


def create_training_pipeline_text_entity_extraction_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    # Set the project and region used by subsequent SDK calls.
    aiplatform.init(project=project, location=location)

    # Define an AutoML text training job for entity extraction.
    job = aiplatform.AutoMLTextTrainingJob(
        display_name=display_name, prediction_type="extraction"
    )

    # Reference the existing text dataset by its ID.
    text_dataset = aiplatform.TextDataset(dataset_id)

    # Start training; with sync=True this call blocks until training finishes.
    model = job.run(
        dataset=text_dataset,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model
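
The budget_milli_node_hours parameter in the Python sample is expressed in thousandths of a node hour, so the default of 8000 corresponds to 8 node hours. The following small helper is a hypothetical convenience, not part of the Vertex AI SDK, that makes the conversion explicit before you pass the value to the sample function.

```python
def node_hours_to_milli_node_hours(node_hours: float) -> int:
    """Convert a training budget in node hours to milli node hours (hypothetical helper)."""
    if node_hours <= 0:
        raise ValueError("The training budget must be positive")
    # 1 node hour equals 1,000 milli node hours.
    return round(node_hours * 1000)


if __name__ == "__main__":
    # The sample's default budget of 8 node hours.
    print(node_hours_to_milli_node_hours(8))
```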
 

Control the data split using REST

You can control how your training data is split between the training, validation, and test sets. When using the Vertex AI API, use the Split object to determine your data split. The Split object can be included in the InputDataConfig object (the inputDataConfig field of the request body) as one of several object types, each of which provides a different way to split the training data. You can select only one method.

  • FractionSplit:
    • TRAINING_FRACTION: The fraction of the training data to be used for the training set.
    • VALIDATION_FRACTION: The fraction of the training data to be used for the validation set. Not used for video data.
    • TEST_FRACTION: The fraction of the training data to be used for the test set.

    If any of the fractions are specified, all must be specified. The fractions must add up to 1.0. The default values for the fractions differ depending on your data type. Learn more.

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
  • FilterSplit:
    • TRAINING_FILTER: Data items that match this filter are used for the training set.
    • VALIDATION_FILTER: Data items that match this filter are used for the validation set. Must be "-" for video data.
    • TEST_FILTER: Data items that match this filter are used for the test set.

    These filters can be used with the ml_use label, or with any labels you apply to your data. Learn more about using the ml_use label and other labels to filter your data.

    The following example shows how to use the filterSplit object with the ml_use label, with the validation set included:

    "filterSplit": {
      "trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
      "validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
      "testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
    }
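
The split rules above (all three fractions specified, fractions summing to 1.0, and only one split method per request) can be checked before a request is sent. The following is a minimal sketch using only the Python standard library. The helper names fraction_split, ml_use_filter_split, and input_data_config are hypothetical and not part of any Google client library; the field names and filter strings are taken from the JSON fragments above.

```python
import math

# Label key used by the filterSplit example above.
ML_USE_LABEL = "labels.aiplatform.googleapis.com/ml_use"


def fraction_split(training, validation, test):
    """Build a fractionSplit object after checking that the fractions sum to 1.0."""
    fractions = (training, validation, test)
    if any(f < 0 or f > 1 for f in fractions):
        raise ValueError("Each fraction must be between 0 and 1")
    if not math.isclose(sum(fractions), 1.0, abs_tol=1e-9):
        raise ValueError("The fractions must add up to 1.0")
    return {
        "fractionSplit": {
            "trainingFraction": training,
            "validationFraction": validation,
            "testFraction": test,
        }
    }


def ml_use_filter_split():
    """Build a filterSplit object that selects items by their ml_use label."""
    return {
        "filterSplit": {
            "trainingFilter": f"{ML_USE_LABEL}=training",
            "validationFilter": f"{ML_USE_LABEL}=validation",
            "testFilter": f"{ML_USE_LABEL}=test",
        }
    }


def input_data_config(dataset_id, split=None):
    """Build the inputDataConfig object, allowing at most one split method."""
    config = {"datasetId": dataset_id}
    if split is not None:
        if len(split) != 1:
            raise ValueError("Select exactly one split method")
        config.update(split)
    return config


if __name__ == "__main__":
    # Placeholder dataset ID for illustration only.
    print(input_data_config("1234567890", fraction_split(0.8, 0.1, 0.1)))
    print(input_data_config("1234567890", ml_use_filter_split()))
```

The resulting dictionary can be used as the inputDataConfig value in the REST request body. Validating locally gives an earlier and clearer error than waiting for the API to reject the request.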