33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/TypeSwitch.h"
54 .Case([](arith::AddFOp op) {
return ReductionKind::Sum; })
55 .Case([](arith::MulFOp op) {
return ReductionKind::Product; })
57 .Case([](arith::MaximumFOp op) {
return ReductionKind::Max; })
58 .Case([](arith::MinimumFOp op) {
return ReductionKind::Min; })
60 .Case([](arith::AddIOp op) {
return ReductionKind::Sum; })
61 .Case([](arith::OrIOp op) {
return ReductionKind::BitwiseOr; })
62 .Case([](arith::XOrIOp op) {
return ReductionKind::BitwiseXor; })
63 .Case([](arith::AndIOp op) {
return ReductionKind::Sum; })
73 .Case([](arith::MaxUIOp op) {
return ReductionKind::Max; })
74 .Case([](arith::MinUIOp op) {
return ReductionKind::Min; })
75 .Case([](arith::MaxSIOp op) {
return ReductionKind::Max; })
76 .Case([](arith::MinSIOp op) {
return ReductionKind::Min; })
77 .Case([](arith::MulIOp op) {
return ReductionKind::Product; })
78 .Default([](
Operation *op) {
return ReductionKind::Generic; });
84 if (!reducedValue || combinerOps.size() != 1) {
88 return combinerOps[0];
94 return ReductionKind::Generic;
96 [[maybe_unused]]
Type resultElementType =
97 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
101 assert(resultElementType == reductionOp.value()->getResult(0).getType());
110 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
116 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
132 LinalgOp op,
int operandNumber,
Value spmdizedOperand,
136 meshOp.getSymName(), reductionMeshAxes, builder);
138 Value isLeadProcess = builder.
create<arith::CmpIOp>(
139 builder.
getI1Type(), arith::CmpIPredicate::eq,
140 processLinearIndexInReductionGroup, zero);
141 scf::IfOp ifOp = builder.
create<scf::IfOp>(spmdizedOperand.
getType(),
142 isLeadProcess,
true,
true);
147 builder.
create<scf::YieldOp>(spmdizedOperand);
158 matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159 assert(combinerOps.size() == 1);
160 std::optional<TypedAttr> neutralEl =
163 Value init = builder.
create<tensor::EmptyOp>(op.getLoc(), shape,
164 neutralEl.value().getType());
166 builder.
create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167 Value fill = builder.
create<linalg::FillOp>(op.getLoc(), constant, init)
170 builder.
create<scf::YieldOp>(fill);
172 return ifOp.getResult(0);
183 assert(op.getNumDpsInits() == 1 &&
"Multiple initial values not supported.");
185 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
186 Value spmdizedInitOperand =
187 spmdizationMap.
lookup(op->getOperands()[operandIdx]);
189 op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
198 llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
203 if (allReduceMeshAxes.empty()) {
207 Value spmdizedLinalgOpResult = spmdizationMap.
lookup(unshardedLinalgOpResult);
208 Value reducedValue = builder.
create<mesh::AllReduceOp>(
209 spmdizedLinalgOpResult, resultSharding.
getMesh(), allReduceMeshAxes,
211 spmdizationMap.
map(unshardedLinalgOpResult, reducedValue);
219 for (
auto [unshardedLinalgOpResult, resultSharding] :
220 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
222 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
223 reductionKind, spmdizationMap, builder);
235 MeshOp mesh =
getMesh(op, operandShardings, resultShardings, symbolTable);
237 loopIteratorTypes, meshAxisAssignmentForLoopIterators);
241 spmdizationMap, builder);
246 for (
auto [unshardedOperand, spmdizedOperand] :
247 llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
248 internalSpmdizationMap.
map(unshardedOperand, spmdizedOperand);
251 *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
252 internalSpmdizationMap, symbolTable, builder);
253 for (
Value result : op->getResults()) {
254 spmdizationMap.
map(result, internalSpmdizationMap.
lookup(result));
259 op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
267 template <
typename Op>
268 struct StructuredOpShardingInterface
269 :
public mesh::ShardingInterface::ExternalModel<
270 StructuredOpShardingInterface<Op>, Op> {
272 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
275 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
276 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
277 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
281 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
282 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
288 SmallVector<ReductionKind>
289 getReductionLoopIteratorKinds(Operation *op)
const {
290 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
291 SmallVector<utils::IteratorType> iteratorTypes =
292 linalgOp.getIteratorTypesArray();
293 unsigned reductionItersCount = std::accumulate(
294 iteratorTypes.begin(), iteratorTypes.end(), 0,
295 [](
unsigned count, utils::IteratorType iter) {
296 return count + (iter == utils::IteratorType::reduction);
299 return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
302 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
303 ArrayRef<MeshSharding> operandShardings,
304 ArrayRef<MeshSharding> resultShardings,
305 IRMapping &spmdizationMap,
306 SymbolTableCollection &symbolTable,
307 OpBuilder &builder)
const {
308 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
310 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
311 bool allIndexingMapsAreProjectedPermutation =
312 llvm::all_of(indexingMaps, [](AffineMap map) {
313 return map.isProjectedPermutation();
315 if (!allIndexingMapsAreProjectedPermutation) {
317 return op->emitOpError()
318 <<
"supports indexing maps that are only projected permutation.";
321 SmallVector<utils::IteratorType> loopIteratorTypes =
322 linalgOp.getIteratorTypesArray();
325 loopIteratorTypes, indexingMaps);
327 loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
328 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
330 linalgOp, spmdizedOperands, operandShardings, resultShardings,
331 loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
332 symbolTable, implicitLocBuilder);
335 operandShardings, resultShardings,
336 spmdizationMap, symbolTable, builder);
345 template <
typename OpType>
347 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
351 template <
typename... OpTypes>
353 (registerOne<OpTypes>(ctx), ...);
359 registry.
insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
360 tensor::TensorDialect>();
365 registerOne<linalg::GenericOp>(ctx);
368 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
auto getDialectNames() const
Return the names of dialects known to this registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
ArrayRef< MeshAxis > getPartialAxes() const
::llvm::StringRef getMesh() const
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
static MeshOp getMesh(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, SymbolTableCollection &symbolTable)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, int operandNumber, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
mesh::ReductionKind ReductionKind
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshSharding resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::MeshSharding MeshSharding
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...