31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
49 .Case([](arith::AddFOp op) {
return ReductionKind::Sum; })
50 .Case([](arith::MulFOp op) {
return ReductionKind::Product; })
52 .Case([](arith::MaximumFOp op) {
return ReductionKind::Max; })
53 .Case([](arith::MinimumFOp op) {
return ReductionKind::Min; })
55 .Case([](arith::AddIOp op) {
return ReductionKind::Sum; })
56 .Case([](arith::OrIOp op) {
return ReductionKind::BitwiseOr; })
57 .Case([](arith::XOrIOp op) {
return ReductionKind::BitwiseXor; })
// Bitwise AND must map to its own reduction kind (parallel to OrIOp ->
// BitwiseOr and XOrIOp -> BitwiseXor); treating it as Sum would make the
// cross-device all-reduce add partial results instead of AND-ing them.
58 .Case([](arith::AndIOp op) {
return ReductionKind::BitwiseAnd; })
68 .Case([](arith::MaxUIOp op) {
return ReductionKind::Max; })
69 .Case([](arith::MinUIOp op) {
return ReductionKind::Min; })
70 .Case([](arith::MaxSIOp op) {
return ReductionKind::Max; })
71 .Case([](arith::MinSIOp op) {
return ReductionKind::Min; })
72 .Case([](arith::MulIOp op) {
return ReductionKind::Product; })
73 .Default([](
Operation *op) {
return ReductionKind::Generic; });
79 if (!reducedValue || combinerOps.size() != 1) {
83 return combinerOps[0];
89 return ReductionKind::Generic;
91 [[maybe_unused]]
Type resultElementType =
92 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
96 assert(resultElementType == reductionOp.value()->getResult(0).getType());
103 for (
const Sharding &sharding : operandShardings) {
109 for (
const Sharding &sharding : resultShardings) {
127 LinalgOp op,
int operandNumber,
Value partitionedOperand,
131 gridOp.getSymName(), reductionGridAxes, builder);
133 Value isLeadProcess = arith::CmpIOp::create(
134 builder, builder.
getI1Type(), arith::CmpIPredicate::eq,
135 processLinearIndexInReductionGroup, zero);
136 scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.
getType(),
137 isLeadProcess,
true,
true);
142 scf::YieldOp::create(builder, partitionedOperand);
153 matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
154 assert(combinerOps.size() == 1);
155 std::optional<TypedAttr> neutralEl =
158 Value init = tensor::EmptyOp::create(builder, op.
getLoc(), shape,
159 neutralEl.value().getType());
161 arith::ConstantOp::create(builder, op.
getLoc(), neutralEl.value());
162 Value fill = linalg::FillOp::create(builder, op.
getLoc(), constant, init)
165 scf::YieldOp::create(builder, fill);
167 return ifOp.getResult(0);
178 assert(op.getNumDpsInits() == 1 &&
"Multiple initial values not supported.");
180 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
181 Value partitionedInitOperand =
182 partitionMap.
lookup(op->getOperands()[operandIdx]);
184 op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder);
193 for (
auto [unshardedLinalgOpResult, resultSharding] :
194 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
195 Value partitionedLinalgOpResult =
196 partitionMap.
lookup(unshardedLinalgOpResult);
197 Value reducedValue = shard::AllReduceOp::create(
198 builder, partitionedLinalgOpResult, resultSharding.getGrid(),
199 opReductionGridAxes, reductionKind);
200 partitionMap.
map(unshardedLinalgOpResult, reducedValue);
211 GridOp grid =
getGrid(op, operandShardings, resultShardings, symbolTable);
213 loopIteratorTypes, gridAxisAssignmentForLoopIterators);
216 reductionGridAxes, partitionMap,
222 for (
auto [unshardedOperand, partitionedOperand] :
223 llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) {
224 internalPartitionMap.
map(unshardedOperand, partitionedOperand);
227 *op, partitionedLinalgOpOperands, operandShardings, resultShardings,
228 internalPartitionMap, symbolTable, builder);
229 for (
Value result : op->getResults()) {
230 partitionMap.
map(result, internalPartitionMap.
lookup(result));
235 op, reductionGridAxes, resultShardings, partitionMap, builder);
243 template <
typename Op>
244 struct StructuredOpShardingInterface
245 :
public shard::ShardingInterface::ExternalModel<
246 StructuredOpShardingInterface<Op>, Op> {
248 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
251 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
252 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
253 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
257 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
258 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
264 SmallVector<ReductionKind>
265 getReductionLoopIteratorKinds(Operation *op)
const {
266 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
267 SmallVector<utils::IteratorType> iteratorTypes =
268 linalgOp.getIteratorTypesArray();
269 unsigned reductionItersCount = llvm::accumulate(
270 iteratorTypes, 0u, [](
unsigned count, utils::IteratorType iter) {
271 return count + (iter == utils::IteratorType::reduction);
274 return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
277 LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
278 ArrayRef<Sharding> operandShardings,
279 ArrayRef<Sharding> resultShardings,
280 IRMapping &partitionMap,
281 SymbolTableCollection &symbolTable,
282 OpBuilder &builder)
const {
283 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
285 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
286 bool allIndexingMapsAreProjectedPermutation =
287 llvm::all_of(indexingMaps, [](AffineMap map) {
288 return map.isProjectedPermutation();
290 if (!allIndexingMapsAreProjectedPermutation) {
292 return op->emitOpError()
293 <<
"supports indexing maps that are only projected permutation.";
296 SmallVector<utils::IteratorType> loopIteratorTypes =
297 linalgOp.getIteratorTypesArray();
300 loopIteratorTypes, indexingMaps);
302 loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
303 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
305 linalgOp, partitionedOperands, operandShardings, resultShardings,
306 loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
307 symbolTable, implicitLocBuilder);
310 operandShardings, resultShardings,
311 partitionMap, symbolTable, builder);
320 template <
typename OpType>
322 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
326 template <
typename... OpTypes>
328 (registerOne<OpTypes>(ctx), ...);
334 registry.
insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
335 tensor::TensorDialect>();
340 registerOne<linalg::GenericOp>(ctx);
343 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the given op.
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, int operandNumber, Value partitionedOperand, ArrayRef< GridAxis > reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder)
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< GridAxis > opReductionGridAxes, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
shard::ShardingArray ShardingArray
shard::ReductionKind ReductionKind
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, GridOp gridOp, ArrayRef< Value > partitionedOperands, ArrayRef< GridAxis > reductionGridAxes, IRMapping &partitionMap, ImplicitLocOpBuilder &builder)
static void partitionLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators, IRMapping &partitionMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
void registerShardingInterfaceExternalModels(DialectRegistry &registry)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static GridOp getGrid(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, SymbolTableCollection &symbolTable)
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
SmallVector< SmallVector< GridAxis > > ShardingArray
TypedValue< IndexType > createProcessLinearIndex(StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...