pkg/controllers/job/plugins/distributed-framework/tensorflow/tensorflow.go (144 lines of code) (raw):

/* Copyright 2021 The Volcano Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package tensorflow import ( "encoding/json" "flag" "fmt" "strconv" v1 "k8s.io/api/core/v1" "k8s.io/klog" batch "volcano.sh/apis/pkg/apis/batch/v1alpha1" jobhelpers "volcano.sh/volcano/pkg/controllers/job/helpers" pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface" ) const ( // DefaultPort defines default port for service DefaultPort = 2222 // TFConfig defines environment variables for TF TFConfig = "TF_CONFIG" ) type tensorflowPlugin struct { tfArguments []string Clientset pluginsinterface.PluginClientset psName string workerName string chiefName string evaluatorName string port int } // New creates tensorflow plugin. func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface { tp := tensorflowPlugin{tfArguments: arguments, Clientset: client} tp.addFlags() return &tp } func (tp *tensorflowPlugin) addFlags() { flagSet := flag.NewFlagSet(tp.Name(), flag.ContinueOnError) flagSet.StringVar(&tp.psName, "ps", "ps", "name of ps role task") flagSet.StringVar(&tp.workerName, "worker", "worker", "name of ps role task") flagSet.StringVar(&tp.chiefName, "chief", "chief", "name of chief role task") flagSet.StringVar(&tp.evaluatorName, "evaluator", "evaluator", "name of evaluator role task") flagSet.IntVar(&tp.port, "port", DefaultPort, "service port") if err := flagSet.Parse(tp.tfArguments); err != nil { klog.Errorf("plugin %s flagset parse failed, err: %v", tp.Name(), err) } } func (tp *tensorflowPlugin) Name() string { return "tensorflow" } func (tp *tensorflowPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error { // No need to generate TF_CONFIG for stand-alone tensorflow job if len(job.Spec.Tasks) == 1 && job.Spec.Tasks[0].Replicas == 1 { return nil } // Generate TF_CONFIG value spec, err := tp.generateTFClusterSpec(pod, job) if err != nil { return err } raw, err := json.Marshal(spec) if err != nil { return err } // Add TF_CONFIG enviroment variables for i := range pod.Spec.Containers { pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{ Name: TFConfig, Value: string(raw), }) } return nil } func (tp *tensorflowPlugin) OnJobAdd(job *batch.Job) error { if job.Status.ControlledResources["plugin-"+tp.Name()] == tp.Name() { return nil } job.Status.ControlledResources["plugin-"+tp.Name()] = tp.Name() return nil } func (tp *tensorflowPlugin) OnJobDelete(job *batch.Job) error { if job.Status.ControlledResources["plugin-"+tp.Name()] != tp.Name() { return nil } delete(job.Status.ControlledResources, "plugin-"+tp.Name()) return nil } func (tp *tensorflowPlugin) OnJobUpdate(job *batch.Job) error { return nil } func (tp *tensorflowPlugin) generateTFClusterSpec(pod *v1.Pod, job *batch.Job) (tfClusterSpec, error) { index, err := strconv.Atoi(jobhelpers.GetPodIndexUnderTask(pod)) if err != nil { return tfClusterSpec{}, err } // Generate tensorflow task info c := tfClusterSpec{ Task: taskInfo{ Type: tp.getTaskType(jobhelpers.GetTaskKey(pod)), Index: index, }, } // Generate tensorflow cluster info for _, ts := range job.Spec.Tasks { hosts := []string{} for i := 0; i < int(ts.Replicas); i++ { hosts = append(hosts, fmt.Sprintf("%s:%d", jobhelpers.MakeDomainName(ts, job, i), tp.port)) } switch ts.Name { case tp.psName: c.Cluster.PS = hosts case tp.workerName: c.Cluster.Worker = hosts case tp.chiefName: c.Cluster.Chief = hosts case tp.evaluatorName: c.Cluster.Evaluator = hosts } } return c, nil } func (tp *tensorflowPlugin) getTaskType(taskKey string) tfTaskType { switch taskKey { case tp.chiefName: return tfChief case tp.workerName: return tfWorker case tp.psName: return tfPS case tp.evaluatorName: return tfEvaluator } return tfTaskType(taskKey) } // TfClusterSpec is the spec of a tensorflow cluster // It will be injected into container's environment variables, and be used by tensorflow framework. // e.g. // { // "cluster": { // "worker": ["worker-0:2222", "worker-1:2222"], // "ps": ["ps-0:2222"] // }, // "task": { // "type": "worker", // "index": 0 // } // } type tfClusterSpec struct { Cluster clusterInfo `json:"cluster"` Task taskInfo `json:"task"` } type clusterInfo struct { PS []string `json:"ps,omitempty"` Worker []string `json:"worker,omitempty"` Chief []string `json:"chief,omitempty"` Evaluator []string `json:"evaluator,omitempty"` } type tfTaskType string const ( tfWorker tfTaskType = "worker" tfChief tfTaskType = "chief" tfPS tfTaskType = "ps" tfEvaluator tfTaskType = "evaluator" ) type taskInfo struct { Type tfTaskType `json:"type"` Index int `json:"index"` }