33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/TypeSwitch.h"
54 .Case([](arith::AddFOp op) {
return ReductionKind::Sum; })
55 .Case([](arith::MulFOp op) {
return ReductionKind::Product; })
57 .Case([](arith::MaximumFOp op) {
return ReductionKind::Max; })
58 .Case([](arith::MinimumFOp op) {
return ReductionKind::Min; })
60 .Case([](arith::AddIOp op) {
return ReductionKind::Sum; })
61 .Case([](arith::OrIOp op) {
return ReductionKind::BitwiseOr; })
62 .Case([](arith::XOrIOp op) {
return ReductionKind::BitwiseXor; })
63 .Case([](arith::AndIOp op) {
// arith.andi combines values with bitwise AND, so it must map to the
// bitwise-and reduction (like the OrIOp/XOrIOp cases), not to a sum.
return ReductionKind::BitwiseAnd; })
73 .Case([](arith::MaxUIOp op) {
return ReductionKind::Max; })
74 .Case([](arith::MinUIOp op) {
return ReductionKind::Min; })
75 .Case([](arith::MaxSIOp op) {
return ReductionKind::Max; })
76 .Case([](arith::MinSIOp op) {
return ReductionKind::Min; })
77 .Case([](arith::MulIOp op) {
return ReductionKind::Product; })
78 .Default([](
Operation *op) {
return ReductionKind::Generic; });
84 if (!reducedValue || combinerOps.size() != 1) {
88 return combinerOps[0];
94 return ReductionKind::Generic;
96 [[maybe_unused]]
Type resultElementType =
97 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
101 assert(resultElementType == reductionOp.value()->getResult(0).getType());
110 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
116 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
135 meshOp.getSymName(), reductionMeshAxes, builder);
137 Value isLeadProcess = builder.
create<arith::CmpIOp>(
138 builder.
getI1Type(), arith::CmpIPredicate::eq,
139 processLinearIndexInReductionGroup, zero);
140 scf::IfOp ifOp = builder.
create<scf::IfOp>(spmdizedOperand.
getType(),
141 isLeadProcess,
true,
true);
146 builder.
create<scf::YieldOp>(spmdizedOperand);
155 PartialReductionOpInterface partialReductionIface =
156 llvm::cast<PartialReductionOpInterface>(op.getOperation());
157 assert(op->getNumResults() == 1 &&
"Multiple results not supported.");
158 FailureOr<SmallVector<Value>> reductionNeutralTensor =
159 partialReductionIface.generateInitialTensorForPartialReduction(
160 builder, builder.
getLoc(), shape, {});
161 assert(succeeded(reductionNeutralTensor));
162 builder.
create<scf::YieldOp>(reductionNeutralTensor.value());
164 return ifOp.getResult(0);
175 assert(op.getNumDpsInits() == 1 &&
"Multiple initial values not supported.");
177 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
178 Value spmdizedInitOperand =
179 spmdizationMap.
lookup(op->getOperands()[operandIdx]);
181 op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
190 llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
195 if (allReduceMeshAxes.empty()) {
199 Value spmdizedLinalgOpResult = spmdizationMap.
lookup(unshardedLinalgOpResult);
200 Value reducedValue = builder.
create<mesh::AllReduceOp>(
201 spmdizedLinalgOpResult, resultSharding.
getMesh(), allReduceMeshAxes,
203 spmdizationMap.
map(unshardedLinalgOpResult, reducedValue);
211 for (
auto [unshardedLinalgOpResult, resultSharding] :
212 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
214 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
215 reductionKind, spmdizationMap, builder);
227 MeshOp mesh =
getMesh(op, operandShardings, resultShardings, symbolTable);
229 loopIteratorTypes, meshAxisAssignmentForLoopIterators);
233 spmdizationMap, builder);
238 for (
auto [unshardedOperand, spmdizedOperand] :
239 llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
240 internalSpmdizationMap.
map(unshardedOperand, spmdizedOperand);
243 *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
244 internalSpmdizationMap, symbolTable, builder);
245 for (
Value result : op->getResults()) {
246 spmdizationMap.
map(result, internalSpmdizationMap.
lookup(result));
251 op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
259 template <
typename Op>
260 struct StructuredOpShardingInterface
261 :
public mesh::ShardingInterface::ExternalModel<
262 StructuredOpShardingInterface<Op>, Op> {
264 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
267 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
268 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
269 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
273 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
274 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
280 SmallVector<ReductionKind>
281 getReductionLoopIteratorKinds(Operation *op)
const {
282 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
283 SmallVector<utils::IteratorType> iteratorTypes =
284 linalgOp.getIteratorTypesArray();
285 unsigned reductionItersCount = std::accumulate(
286 iteratorTypes.begin(), iteratorTypes.end(), 0,
287 [](
unsigned count, utils::IteratorType iter) {
288 return count + (iter == utils::IteratorType::reduction);
291 return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
294 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
295 ArrayRef<MeshSharding> operandShardings,
296 ArrayRef<MeshSharding> resultShardings,
297 IRMapping &spmdizationMap,
298 SymbolTableCollection &symbolTable,
299 OpBuilder &builder)
const {
300 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
302 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
303 bool allIndexingMapsAreProjectedPermutation =
304 llvm::all_of(indexingMaps, [](AffineMap map) {
305 return map.isProjectedPermutation();
307 if (!allIndexingMapsAreProjectedPermutation) {
309 return op->emitOpError()
310 <<
"supports indexing maps that are only projected permutation.";
313 SmallVector<utils::IteratorType> loopIteratorTypes =
314 linalgOp.getIteratorTypesArray();
317 loopIteratorTypes, indexingMaps);
319 loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
320 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
322 linalgOp, spmdizedOperands, operandShardings, resultShardings,
323 loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
324 symbolTable, implicitLocBuilder);
327 operandShardings, resultShardings,
328 spmdizationMap, symbolTable, builder);
337 template <
typename OpType>
339 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
343 template <
typename... OpTypes>
345 (registerOne<OpTypes>(ctx), ...);
351 registry.
insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
352 tensor::TensorDialect>();
357 registerOne<linalg::GenericOp>(ctx);
360 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
ArrayRef< MeshAxis > getPartialAxes() const
::llvm::StringRef getMesh() const
static MeshOp getMesh(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, SymbolTableCollection &symbolTable)
mesh::ReductionKind ReductionKind
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshSharding resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::MeshSharding MeshSharding
static Value createDestinationPassingStyleInitOperand(LinalgOp op, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...