ShardingInterfaceImpl.h
//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_

#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"

namespace mlir {

class Operation;
class IRMapping;
class SymbolTableCollection;

namespace shard {

// Retrieve the grid axes corresponding to each operation loop iterator based
// on the provided shardings for the op's operands and results.
// Assumes that the indexingMaps are projected permutations.
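//
// For example, given matmul-style indexing maps
//   (i, j, k) -> (i, k), (i, j, k) -> (k, j), (i, j, k) -> (i, j),
// an operand whose tensor dimension mapped to iterator i is split over grid
// axis 0 contributes axis 0 to iterator i's assignment, while iterators whose
// corresponding tensor dimensions are nowhere sharded get an empty axis list.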
ShardingArray getGridAxisAssignmentForLoopIterators(
    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<AffineMap> indexingMaps);

// Check whether any reduction loop iterator is sharded over at least one grid
// axis.
bool isAtLeastOneReductionIteratorSharded(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);

// Get the set of grid axes that correspond to reduction loop iterators.
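// For example, in a matmul whose contraction iterator is assigned grid axis 1,
// this returns {1}; a partitioned implementation would then typically need an
// all-reduce over that axis to combine partial results.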
SmallVector<GridAxis> getReductionGridAxes(
    ArrayRef<utils::IteratorType> loopIteratorTypes,
    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);

// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
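//
// For illustration: an elementwise op on tensor<8x16xf32> whose operand and
// result shardings split dimension 0 over a 2-way grid axis is cloned as the
// same op on tensor<4x16xf32> shard-local values, with the mapping from the
// original results to the cloned results recorded in partitionMap.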
void partitionTriviallyShardableOperation(Operation &op,
                                          ArrayRef<Value> partitionedOperands,
                                          ArrayRef<Sharding> operandShardings,
                                          ArrayRef<Sharding> resultShardings,
                                          IRMapping &partitionMap,
                                          SymbolTableCollection &symbolTable,
                                          OpBuilder &builder);

// All ranked tensor argument and result dimensions have
// independent parallel loop iterators.
template <typename Op>
struct IndependentParallelIteratorDomainShardingInterface
    : public ShardingInterface::ExternalModel<
          IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
  SmallVector<utils::IteratorType>
  getLoopIteratorTypes(Operation *operation) const {
    SmallVector<utils::IteratorType> iterTypes;
    for (Type t : operation->getOperandTypes()) {
      populateIteratorTypes(t, iterTypes);
    }
    for (Type t : operation->getResultTypes()) {
      populateIteratorTypes(t, iterTypes);
    }
    return iterTypes;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    // TODO: implement.
    return SmallVector<AffineMap>();
  }

  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
                          ArrayRef<Sharding> operandShardings,
                          ArrayRef<Sharding> resultShardings,
                          IRMapping &partitionMap,
                          SymbolTableCollection &symbolTable,
                          OpBuilder &builder) const {
    partitionTriviallyShardableOperation(*op, partitionedOperands,
                                         operandShardings, resultShardings,
                                         partitionMap, symbolTable, builder);
    return success();
  }

private:
  void
  populateIteratorTypes(Type t,
                        SmallVector<utils::IteratorType> &iterTypes) const {
    RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
    if (!rankedTensorType) {
      return;
    }

    iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
    for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
      iterTypes.push_back(utils::IteratorType::parallel);
    }
  }
};
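
// A typical way to attach this external model to an op (sketch; MyDialect,
// MyOp, and the registration function are placeholder names):
//
//   void registerMyOpShardingInterface(DialectRegistry &registry) {
//     registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
//       MyOp::attachInterface<
//           IndependentParallelIteratorDomainShardingInterface<MyOp>>(*ctx);
//     });
//   }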

// Sharding of elementwise operations like tensor addition and multiplication.
template <typename ElemwiseOp>
struct ElementwiseShardingInterface
    : public ShardingInterface::ExternalModel<
          ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    SmallVector<utils::IteratorType> types(type.getRank(),
                                           utils::IteratorType::parallel);
    return types;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    Value val = op->getOperand(0);
    auto type = dyn_cast<RankedTensorType>(val.getType());
    if (!type)
      return {};
    int64_t rank = type.getRank();
    int64_t num = op->getNumOperands() + op->getNumResults();
    SmallVector<AffineMap> maps(num,
                                AffineMap::getMultiDimIdentityMap(rank, ctx));
    return maps;
  }

  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
                          ArrayRef<Sharding> operandShardings,
                          ArrayRef<Sharding> resultShardings,
                          IRMapping &partitionMap,
                          SymbolTableCollection &symbolTable,
                          OpBuilder &builder) const {
    partitionTriviallyShardableOperation(*op, partitionedOperands,
                                         operandShardings, resultShardings,
                                         partitionMap, symbolTable, builder);
    return success();
  }
};
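
// Likewise, a concrete elementwise op can opt in along these lines (sketch;
// MyDialect and MyAddOp are placeholder names):
//
//   registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
//     MyAddOp::attachInterface<ElementwiseShardingInterface<MyAddOp>>(*ctx);
//   });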

} // namespace shard
} // namespace mlir

#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_