25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
29#define DEBUG_TYPE "linalg-specialization"
50 Block *body = genericOp.getBody();
57 "binary op uses just one block arg");
81 Block *body = genericOp.getBody();
83 for (
auto [i, v] : llvm::enumerate(op->
getOperands())) {
84 if (
auto blockArg = dyn_cast<BlockArgument>(v);
85 blockArg && blockArg.getOwner() == body)
123 bool emitCategoryOp) {
124 bool hasNonIdentityMaps =
125 !llvm::all_of(genericOp.getIndexingMapsArray(),
126 [](
AffineMap map) { return map.isIdentity(); });
129 if (hasNonIdentityMaps && !emitCategoryOp)
132 "non-identity indexing maps prevent specialization to named op");
135 bool isUnary = genericOp.getNumDpsInputs() == 1;
136 bool isBinary = genericOp.getNumDpsInputs() == 2;
139 Operation *op = &genericOp.getBody()->front();
143 int scalarOprIdx = -1;
150 auto replaceOp = [&](
auto namedOp, ElementwiseKind kind,
151 bool mayHoistScalarOperand =
true) -> LinalgOp {
153 if (hasSwappedOperands)
154 std::swap(inputs[0], inputs[1]);
157 if (!emitCategoryOp) {
158 using NamedOpTy =
decltype(namedOp);
159 if constexpr (!std::is_null_pointer_v<NamedOpTy>)
160 newOp = NamedOpTy::create(rewriter, genericOp.getLoc(), inputs,
161 genericOp.getDpsInits(),
164 llvm_unreachable(
"Missing named op type");
168 if (hasSwappedOperands)
169 std::swap(indexingMaps[0], indexingMaps[1]);
173 if (hasScalarOperand && mayHoistScalarOperand) {
175 inputs.insert(inputs.begin() + scalarOprIdx,
177 auto scalarBroadcastMap =
180 indexingMaps.insert(indexingMaps.begin() + scalarOprIdx,
183 newOp = ElementwiseOp::create(
184 rewriter, genericOp.getLoc(), inputs, genericOp.getDpsInits(),
185 ElementwiseKindAttr::get(rewriter.
getContext(), kind),
194 if (isa<math::ExpOp>(op))
195 return replaceOp(ExpOp{}, ElementwiseKind::exp);
196 if (isa<math::LogOp>(op))
197 return replaceOp(LogOp{}, ElementwiseKind::log);
198 if (isa<math::AbsFOp>(op))
199 return replaceOp(AbsOp{}, ElementwiseKind::abs);
200 if (isa<math::CeilOp>(op))
201 return replaceOp(CeilOp{}, ElementwiseKind::ceil);
202 if (isa<math::FloorOp>(op))
203 return replaceOp(FloorOp{}, ElementwiseKind::floor);
204 if (isa<arith::NegFOp>(op))
205 return replaceOp(NegFOp{}, ElementwiseKind::negf);
206 if (
auto divOp = dyn_cast<arith::DivFOp>(op)) {
207 if (
auto constOp = dyn_cast_if_present<arith::ConstantOp>(
208 divOp.getLhs().getDefiningOp()))
209 if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
210 return replaceOp(ReciprocalOp{}, ElementwiseKind::reciprocal,
213 if (isa<math::RoundOp>(op))
214 return replaceOp(RoundOp{}, ElementwiseKind::round);
215 if (isa<math::SqrtOp>(op))
216 return replaceOp(SqrtOp{}, ElementwiseKind::sqrt);
217 if (isa<math::RsqrtOp>(op))
218 return replaceOp(RsqrtOp{}, ElementwiseKind::rsqrt);
219 if (
auto mulOp = dyn_cast<arith::MulFOp>(op);
220 mulOp && mulOp.getLhs() == mulOp.getRhs())
221 return replaceOp(SquareOp{}, ElementwiseKind::square);
222 if (isa<math::TanhOp>(op))
223 return replaceOp(TanhOp{}, ElementwiseKind::tanh);
224 if (isa<math::ErfOp>(op))
225 return replaceOp(ErfOp{}, ElementwiseKind::erf);
231 if (!emitCategoryOp || !hasScalarOperand)
233 genericOp,
"unary elementwise operation cannot be specialized to "
234 "named or category op");
239 [](
Value v) { return v.getType().isInteger(1); });
241 if (isa<arith::AddIOp, arith::AddFOp, complex::AddOp>(op) ||
242 (allBool && isa<arith::OrIOp>(op)))
243 return replaceOp(AddOp{}, ElementwiseKind::add);
244 if (isa<arith::SubIOp, arith::SubFOp, complex::SubOp>(op))
245 return replaceOp(SubOp{}, ElementwiseKind::sub);
246 if (isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op) ||
247 (allBool && isa<arith::AndIOp>(op)))
248 return replaceOp(MulOp{}, ElementwiseKind::mul);
249 if (isa<arith::DivSIOp, arith::DivFOp, complex::DivOp>(op))
250 return replaceOp(DivOp{}, ElementwiseKind::div);
251 if (isa<arith::DivUIOp>(op))
252 return replaceOp(DivUnsignedOp{}, ElementwiseKind::div_unsigned);
253 if (isa<arith::MaxSIOp, arith::MaximumFOp>(op))
254 return replaceOp(MaxOp{}, ElementwiseKind::max_signed);
255 if (isa<arith::MinSIOp, arith::MinimumFOp>(op))
256 return replaceOp(MinOp{}, ElementwiseKind::min_signed);
257 if (emitCategoryOp) {
259 if (isa<arith::MaxUIOp>(op))
260 return replaceOp(
nullptr, ElementwiseKind::max_unsigned);
261 if (isa<arith::MinUIOp>(op))
262 return replaceOp(
nullptr, ElementwiseKind::min_unsigned);
264 if (isa<math::PowFOp>(op))
265 return replaceOp(PowFOp{}, ElementwiseKind::powf);
269 "elementwise operation cannot be specialized to named or category op");
298enum class IndexMatchResult {
312static IndexMatchResult matchOperandMap(
AffineMap map,
unsigned rowDimIdx,
313 unsigned expectedPosOfRowDim,
314 unsigned expectedPosOfColDim) {
316 auto exprOfRowDim = map.
getResults()[rowDimIdx];
317 auto exprOfColDim = map.
getResults()[rowDimIdx + 1];
322 return IndexMatchResult::Mismatch;
324 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
325 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
327 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
328 return IndexMatchResult::Match;
330 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
331 return IndexMatchResult::Transposed;
333 return IndexMatchResult::Mismatch;
342template <
typename NamedOpTy>
343static LinalgOp replaceWithMatmulVariant(
RewriterBase &rewriter, GenericOp op,
344 std::optional<TypeFn> castTy,
349 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
351 "cast", TypeFnAttr::get(rewriter.
getContext(), *castTy));
352 attributes.push_back(castAttr);
359 return AffineMapAttr::get(map);
362 "indexing_maps", rewriter.
getArrayAttr(indexingMapsAttrVal));
363 attributes.push_back(indexingMapsAttr);
366 op,
ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
375static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
376 bool foundCastForMatmulOutput =
false;
378 genericOp.getBody()->walk([&](CastOpInterface castOp) {
387 if (!llvm::any_of(forwardSlice, [](
Operation *op) {
390 return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
392 foundCastForMatmulOutput =
true;
397 if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
398 castTyFns.push_back(TypeFn::cast_unsigned);
399 else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
400 castTyFns.push_back(TypeFn::cast_signed);
405 if (foundCastForMatmulOutput)
408 if (!castTyFns.empty()) {
412 if (!llvm::all_equal(castTyFns))
414 return castTyFns.front();
418 return TypeFn::cast_signed;
421static FailureOr<LinalgOp> specializeLinalgMmt4D(
RewriterBase &rewriter,
423 std::optional<TypeFn> castTy,
426 auto indexingMaps = genericOp.getIndexingMapsArray();
427 if (llvm::any_of(indexingMaps, [](
AffineMap m) {
432 auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.
m[0], dims.
k[0]);
433 auto aInner = matchOperandMap(indexingMaps[0], 2, dims.
m[1], dims.
k[1]);
435 auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.
k[0], dims.
n[0]);
436 auto bInner = matchOperandMap(indexingMaps[1], 2, dims.
k[1], dims.
n[1]);
438 auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.
m[0], dims.
n[0]);
439 auto cInner = matchOperandMap(indexingMaps[2], 2, dims.
m[1], dims.
n[1]);
441 if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
443 if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
449 return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
454static FailureOr<LinalgOp> specializeLinalgContractions(
RewriterBase &rewriter,
456 bool emitCategoryOp) {
457 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
461 auto mapRange = genericOp.getIndexingMapsArray();
462 if (llvm::any_of(mapRange,
471 return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
472 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
473 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
479 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
482 genericOp,
"contains invalid cast ops for the named matmul op");
486 return replaceWithMatmulVariant<ContractOp>(
487 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
514 if (dims.
m.size() == 2 && dims.
n.size() == 2 && dims.
k.size() == 2)
515 return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
516 if (dims.
m.size() != 1 || dims.
n.size() != 1 || dims.
k.size() != 1)
520 auto indexingMaps = genericOp.getIndexingMapsArray();
521 if (llvm::any_of(indexingMaps, [&dims](
AffineMap m) {
523 dims.
batch.size() + 2 ;
527 auto numOfBatchDims = dims.
batch.size();
528 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
531 if (numOfBatchDims) {
535 if (llvm::any_of(indexingMaps, [numOfBatchDims](
AffineMap m) {
536 for (
unsigned i = 0; i < numOfBatchDims; ++i) {
539 cast<AffineDimExpr>(expr).getPosition() != i)
548 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.
m[0], dims.
k[0]);
550 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.
k[0], dims.
n[0]);
552 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.
m[0], dims.
n[0]);
554 if (llvm::is_contained({a,
b, c}, IndexMatchResult::Mismatch))
558 auto *ctx = genericOp.getContext();
559 unsigned numLoopDims = numOfBatchDims + 3;
560 unsigned mIdx = numOfBatchDims;
561 unsigned nIdx = mIdx + 1;
562 unsigned kIdx = mIdx + 2;
565 auto makeMap = [&](IndexMatchResult match,
unsigned rowIdx,
unsigned colIdx) {
567 for (
unsigned i = 0; i < numOfBatchDims; ++i)
568 tensorDims.push_back(i);
569 if (match == IndexMatchResult::Transposed)
570 llvm::append_values(tensorDims, colIdx, rowIdx);
572 llvm::append_values(tensorDims, rowIdx, colIdx);
576 auto mapA = makeMap(a, mIdx, kIdx);
577 auto mapB = makeMap(
b, kIdx, nIdx);
578 auto mapC = makeMap(c, mIdx, nIdx);
583 if (numOfBatchDims) {
584 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
587 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
593template <
typename ConvOpTy>
594static FailureOr<LinalgOp>
595specializeToConvOp(
RewriterBase &rewriter, GenericOp genericOp,
604 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
605 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
606 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
613 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
619static FailureOr<LinalgOp> specializeLinalgConvolutions(
RewriterBase &rewriter,
620 GenericOp genericOp) {
621#define CONV_OP_SPECIALIZER(ConvOpTy) \
622 if (std::optional<DilationsAndStrides> convParams = \
623 matchConvolutionOpOfType<ConvOpTy>(genericOp)) \
624 return specializeToConvOp<ConvOpTy>( \
625 rewriter, genericOp, convParams->dilations, convParams->strides); \
682#undef CONV_OP_SPECIALIZER
703 return specializeLinalgContractions(rewriter, genericOp,
712 genericOp,
"no matching category op specialization");
717 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
725 genericOp, *fillValue, genericOp.getDpsInits()[0]);
730 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
732 if (equivalentToBroadcast) {
733 auto dims = *equivalentToBroadcast;
735 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
741 std::optional<SmallVector<int64_t>> equivalentToTranspose =
743 if (equivalentToTranspose) {
744 auto permutation = *equivalentToTranspose;
746 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
753 return specializeLinalgConvolutions(rewriter, genericOp);
756 "no matching named op specialization");
760struct LinalgSpecializeGenericOpsPass
762 LinalgSpecializeGenericOpsPass> {
765 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
766 void runOnOperation()
override;
770void LinalgSpecializeGenericOpsPass::runOnOperation() {
static llvm::ManagedStatic< PassManagerOptions > options
static bool findIndexOfScalarOperand(GenericOp genericOp, int &index)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
static FailureOr< LinalgOp > specializeLinalgElementwise(RewriterBase &rewriter, GenericOp genericOp, bool emitCategoryOp)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
unsigned getNumOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
OpOperand & getOpOperand(unsigned idx)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static WalkResult interrupt()
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possibl...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
llvm::SetVector< T, Vector, Set, N > SetVector
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k