MLIR  20.0.0git
ShardingInterfaceImpl.h
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
11 
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Value.h"
16 
17 namespace mlir {
18 
19 class Operation;
20 class IRMapping;
21 class SymbolTableCollection;
22 
23 namespace mesh {
24 
25 // Retrieve the mesh axes corresponding to each operation loop iterator based
26 // on the provided shardings for the op's operands and results.
27 // Assumes that the indexingMaps are projected permutations.
29  ArrayRef<MeshSharding> operandShardings,
30  ArrayRef<MeshSharding> resultShardings,
31  ArrayRef<utils::IteratorType> loopIteratorTypes,
32  ArrayRef<AffineMap> indexingMaps);
33 
35  ArrayRef<utils::IteratorType> loopIteratorTypes,
36  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
37 
38 // Returns the mesh axes that correspond to reduction loop iterators.
39 SmallVector<MeshAxis> getReductionMeshAxes(
40  ArrayRef<utils::IteratorType> loopIteratorTypes,
41  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
42 
43 // Inserts a clone of the operation in which all ranked tensor
44 // operands/results are sharded.
45 void spmdizeTriviallyShardableOperation(Operation &op,
46  ArrayRef<Value> spmdizedOperands,
47  ArrayRef<MeshSharding> operandShardings,
48  ArrayRef<MeshSharding> resultShardings,
49  IRMapping &spmdizationMap,
50  SymbolTableCollection &symbolTable,
51  OpBuilder &builder);
52 
53 // All ranked tensor argument and result dimensions have
54 // independent parallel loop iterators.
55 template <typename Op>
57  : public ShardingInterface::ExternalModel<
58  IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
60  getLoopIteratorTypes(Operation *operation) const {
62  for (Type t : operation->getOperandTypes()) {
63  populateIteratorTypes(t, iterTypes);
64  }
65  for (Type t : operation->getResultTypes()) {
66  populateIteratorTypes(t, iterTypes);
67  }
68  return iterTypes;
69  }
70 
72  // TODO: implement.
73  return SmallVector<AffineMap>();
74  }
75 
76  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
77  ArrayRef<MeshSharding> operandShardings,
78  ArrayRef<MeshSharding> resultShardings,
79  IRMapping &spmdizationMap,
80  SymbolTableCollection &symbolTable,
81  OpBuilder &builder) const {
82  spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
83  resultShardings, spmdizationMap,
84  symbolTable, builder);
85  return success();
86  }
87 
88 private:
89  void
90  populateIteratorTypes(Type t,
91  SmallVector<utils::IteratorType> &iterTypes) const {
92  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
93  if (!rankedTensorType) {
94  return;
95  }
96 
97  iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
98  for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
99  iterTypes.push_back(utils::IteratorType::parallel);
100  }
101  }
102 };
103 
104 // Sharding of elementwise operations like tensor addition and multiplication.
105 template <typename ElemwiseOp>
107  : public ShardingInterface::ExternalModel<
108  ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
110  Value val = op->getOperand(0);
111  auto type = dyn_cast<RankedTensorType>(val.getType());
112  if (!type)
113  return {};
114  SmallVector<utils::IteratorType> types(type.getRank(),
115  utils::IteratorType::parallel);
116  return types;
117  }
118 
120  MLIRContext *ctx = op->getContext();
121  Value val = op->getOperand(0);
122  auto type = dyn_cast<RankedTensorType>(val.getType());
123  if (!type)
124  return {};
125  int64_t rank = type.getRank();
126  int64_t num = op->getNumOperands() + op->getNumResults();
127  SmallVector<AffineMap> maps(num,
129  return maps;
130  }
131 
132  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
133  ArrayRef<MeshSharding> operandShardings,
134  ArrayRef<MeshSharding> resultShardings,
135  IRMapping &spmdizationMap,
136  SymbolTableCollection &symbolTable,
137  OpBuilder &builder) const {
138  spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
139  resultShardings, spmdizationMap,
140  symbolTable, builder);
141  return success();
142  }
143 };
144 
145 } // namespace mesh
146 } // namespace mlir
147 
148 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Include the generated interface declarations.
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const