34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/ADT/TypeSwitch.h"
54 .Case([](arith::AddFOp op) {
return ReductionKind::Sum; })
55 .Case([](arith::MulFOp op) {
return ReductionKind::Product; })
57 .Case([](arith::MaximumFOp op) {
return ReductionKind::Max; })
58 .Case([](arith::MinimumFOp op) {
return ReductionKind::Min; })
60 .Case([](arith::AddIOp op) {
return ReductionKind::Sum; })
61 .Case([](arith::OrIOp op) {
return ReductionKind::BitwiseOr; })
62 .Case([](arith::XOrIOp op) {
return ReductionKind::BitwiseXor; })
63 .Case([](arith::AndIOp op) {
return ReductionKind::BitwiseAnd; }) // AND is a bitwise reduction (like OrIOp/XOrIOp above), not a sum.
73 .Case([](arith::MaxUIOp op) {
return ReductionKind::Max; })
74 .Case([](arith::MinUIOp op) {
return ReductionKind::Min; })
75 .Case([](arith::MaxSIOp op) {
return ReductionKind::Max; })
76 .Case([](arith::MinSIOp op) {
return ReductionKind::Min; })
77 .Case([](arith::MulIOp op) {
return ReductionKind::Product; })
78 .Default([](
Operation *op) {
return ReductionKind::Generic; });
84 if (!reducedValue || combinerOps.size() != 1) {
88 return combinerOps[0];
94 return ReductionKind::Generic;
96 [[maybe_unused]]
Type resultElementType =
97 llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
101 assert(resultElementType == reductionOp.value()->getResult(0).getType());
136 meshOp.getSymName(), reductionMeshAxes, builder);
138 Value isLeadProcess = builder.
create<arith::CmpIOp>(
139 builder.
getI1Type(), arith::CmpIPredicate::eq,
140 processLinearIndexInReductionGroup, zero);
141 scf::IfOp ifOp = builder.
create<scf::IfOp>(spmdizedOperand.
getType(),
142 isLeadProcess,
true,
true);
147 builder.
create<scf::YieldOp>(spmdizedOperand);
156 PartialReductionOpInterface partialReductionIface =
157 llvm::cast<PartialReductionOpInterface>(op.getOperation());
159 partialReductionIface.generateInitialTensorForPartialReduction(
160 builder, builder.
getLoc(), shape, {});
161 assert(
succeeded(reductionNeutralTensorOp));
162 builder.
create<scf::YieldOp>(
163 reductionNeutralTensorOp.value()->getResult(0));
165 return ifOp.getResult(0);
179 auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
180 Value spmdizedInitOperand =
181 spmdizationMap.
lookup(op->getOperands()[operandIdx]);
183 op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
192 llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
194 return !llvm::is_contained(resultSharding.getPartialAxes(),
197 if (allReduceMeshAxes.empty()) {
201 Value spmdizedLinalgOpResult = spmdizationMap.
lookup(unshardedLinalgOpResult);
202 Value reducedValue = builder.
create<mesh::AllReduceOp>(
203 spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
204 allReduceMeshAxes, reductionKind);
205 spmdizationMap.
map(unshardedLinalgOpResult, reducedValue);
213 for (
auto [unshardedLinalgOpResult, resultSharding] :
214 llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
216 unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
217 reductionKind, spmdizationMap, builder);
229 MeshOp mesh =
getMesh(op, operandShardings, resultShardings, symbolTable);
231 loopIteratorTypes, meshAxisAssignmentForLoopIterators);
235 spmdizationMap, builder);
240 for (
auto [unshardedOperand, spmdizedOperand] :
241 llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
242 internalSpmdizationMap.
map(unshardedOperand, spmdizedOperand);
245 *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
246 internalSpmdizationMap, symbolTable, builder);
247 for (
Value result : op->getResults()) {
248 spmdizationMap.
map(result, internalSpmdizationMap.
lookup(result));
253 op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
261 template <
typename Op>
262 struct StructuredOpShardingInterface
263 :
public mesh::ShardingInterface::ExternalModel<
264 StructuredOpShardingInterface<Op>, Op> {
266 return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
269 SmallVector<AffineMap> getIndexingMaps(Operation *op)
const {
270 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
271 SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
275 for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
276 res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
282 LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
283 ArrayRef<MeshShardingAttr> operandShardings,
284 ArrayRef<MeshShardingAttr> resultShardings,
285 IRMapping &spmdizationMap,
286 SymbolTableCollection &symbolTable,
287 OpBuilder &builder)
const {
288 LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
290 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
291 bool allIndexingMapsAreProjectedPermutation =
292 llvm::all_of(indexingMaps, [](AffineMap map) {
293 return map.isProjectedPermutation();
295 if (!allIndexingMapsAreProjectedPermutation) {
297 return op->emitOpError()
298 <<
"supports indexing maps that are only projected permutation.";
301 SmallVector<utils::IteratorType> loopIteratorTypes =
302 linalgOp.getIteratorTypesArray();
305 loopIteratorTypes, indexingMaps);
307 loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
308 ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
310 linalgOp, spmdizedOperands, operandShardings, resultShardings,
311 loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
312 symbolTable, implicitLocBuilder);
315 operandShardings, resultShardings,
316 spmdizationMap, symbolTable, builder);
325 template <
typename OpType>
327 OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
331 template <
typename... OpTypes>
333 (registerOne<OpTypes>(ctx), ...);
339 registry.
insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
340 tensor::TensorDialect>();
345 registerOne<linalg::GenericOp>(ctx);
348 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
auto getDialectNames() const
Return the names of dialects known to this registry.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
static MeshOp getMesh(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, SymbolTableCollection &symbolTable)
mesh::ReductionKind ReductionKind
static void createAllReduceForResultsWithoutPartialShardings(LinalgOp unshardedOp, ArrayRef< MeshAxis > opReductionMeshAxes, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
static ReductionKind getReductionKind(Operation *op)
static std::optional< Operation * > getCombinerOp(LinalgOp op)
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
mesh::MeshShardingAttr MeshShardingAttr
static void spmdizeLinalgOpWithShardedReduction(LinalgOp op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder)
static Value createDestinationPassingStyleInitOperand(LinalgOp op, Value spmdizedOperand, ArrayRef< MeshAxis > reductionMeshAxes, MeshOp meshOp, ImplicitLocOpBuilder &builder)
static SmallVector< Value > createDestinationPassingStyleInitOperands(LinalgOp op, MeshOp meshOp, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshAxis > reductionMeshAxes, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
mesh::ShardingArray ShardingArray
static ReductionKind getReductionKindOfLinalgOp(LinalgOp op)
static void createAllReduceForResultWithoutPartialSharding(Value unshardedLinalgOpResult, ArrayRef< MeshAxis > opReductionMeshAxes, MeshShardingAttr resultSharding, ReductionKind reductionKind, IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder)
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
TypedValue< IndexType > createProcessLinearIndex(StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...