MLIR 22.0.0git
ShardingInterfaceImpl.h
Go to the documentation of this file.
1//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
10#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
11
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"
16
17namespace mlir {
18
// Forward declarations of MLIR classes that this header only refers to by
// pointer or reference.
class Operation;
class IRMapping;
class SymbolTableCollection;

23namespace shard {
24
25// Retrieve the grid axes corresponding to each operation loop iterator based
26// on the provided shardings for the op's operands and results.
27// Assumes that the indexingMaps are projected permutations.
29 ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
30 ArrayRef<utils::IteratorType> loopIteratorTypes,
31 ArrayRef<AffineMap> indexingMaps);
32
34 ArrayRef<utils::IteratorType> loopIteratorTypes,
35 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
36
37// Get the set of grid axes that correspond to reduction loop iterators.
38SmallVector<GridAxis> getReductionGridAxes(
39 ArrayRef<utils::IteratorType> loopIteratorTypes,
40 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
41
42// Inserts a clone of the operation that has all ranked tensor
43// arguments/results sharded.
45 ArrayRef<Value> partitionedOperands,
46 ArrayRef<Sharding> operandShardings,
47 ArrayRef<Sharding> resultShardings,
48 IRMapping &partitionMap,
49 SymbolTableCollection &symbolTable,
50 OpBuilder &builder);
51
52// All ranked tensor argument and result dimensions have
53// independent parallel loop iterators.
54template <typename Op>
56 : public ShardingInterface::ExternalModel<
57 IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
59 getLoopIteratorTypes(Operation *operation) const {
61 for (Type t : operation->getOperandTypes()) {
62 populateIteratorTypes(t, iterTypes);
63 }
64 for (Type t : operation->getResultTypes()) {
65 populateIteratorTypes(t, iterTypes);
66 }
67 return iterTypes;
68 }
69
71 // TODO: implement.
73 }
74
75 LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
76 ArrayRef<Sharding> operandShardings,
77 ArrayRef<Sharding> resultShardings,
78 IRMapping &partitionMap,
79 SymbolTableCollection &symbolTable,
80 OpBuilder &builder) const {
81 partitionTriviallyShardableOperation(*op, partitionedOperands,
82 operandShardings, resultShardings,
83 partitionMap, symbolTable, builder);
84 return success();
85 }
86
87private:
88 void
89 populateIteratorTypes(Type t,
90 SmallVector<utils::IteratorType> &iterTypes) const {
91 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
92 if (!rankedTensorType) {
93 return;
94 }
95
96 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
97 for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
98 iterTypes.push_back(utils::IteratorType::parallel);
99 }
100 }
101};
102
103// Sharding of elementwise operations like tensor addition and multiplication.
104template <typename ElemwiseOp>
106 : public ShardingInterface::ExternalModel<
107 ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
109 Value val = op->getOperand(0);
110 auto type = dyn_cast<RankedTensorType>(val.getType());
111 if (!type)
112 return {};
113 SmallVector<utils::IteratorType> types(type.getRank(),
114 utils::IteratorType::parallel);
115 return types;
116 }
117
119 MLIRContext *ctx = op->getContext();
120 Value val = op->getOperand(0);
121 auto type = dyn_cast<RankedTensorType>(val.getType());
122 if (!type)
123 return {};
124 int64_t rank = type.getRank();
125 int64_t num = op->getNumOperands() + op->getNumResults();
126 SmallVector<AffineMap> maps(num,
128 return maps;
129 }
130
131 LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
132 ArrayRef<Sharding> operandShardings,
133 ArrayRef<Sharding> resultShardings,
134 IRMapping &partitionMap,
135 SymbolTableCollection &symbolTable,
136 OpBuilder &builder) const {
137 partitionTriviallyShardableOperation(*op, partitionedOperands,
138 operandShardings, resultShardings,
139 partitionMap, symbolTable, builder);
140 return success();
141 }
142};
143
144} // namespace shard
145} // namespace mlir
146
147#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
return success()
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
unsigned getNumOperands()
Definition Operation.h:346
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
SmallVector< SmallVector< GridAxis > > ShardingArray
Include the generated interface declarations.
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const
LogicalResult partition(Operation *op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
LogicalResult partition(Operation *op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const