// NOTE(review): lossy extraction — the leading digits on each line are the
// original file's line numbers; lines 27-28, 30 and 33 are elided here.
// Pulls in the tablegen-generated pass base class for this file's pass.
25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
// Debug tag used by LLVM_DEBUG-style tracing for this pass.
29#define DEBUG_TYPE "linalg-specialization"
// Replaces the matched linalg.generic with the named binary op NEWOP, feeding
// the two DPS inputs in original or swapped order depending on OPERANDS_SWAP,
// and reusing the single DPS init as the output. Relies on `rewriter` and
// `genericOp` being in scope at the expansion site.
31#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
32 (rewriter.replaceOpWithNewOp<NEWOP>( \
34 ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
35 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
36 ValueRange{genericOp.getDpsInits()[0]}))
// Fragment of areBinOpsSwapped(GenericOp) per the generated index below;
// interior lines (58-63) are elided by the extraction. Inspects the generic
// op's payload block — presumably to decide whether the binary payload uses
// its block arguments in swapped order (TODO confirm against full source).
57 Block *body = genericOp.getBody();
64 "binary op uses just one block arg");
// Fragment of specializeLinalgUnaryElementwise(RewriterBase&, GenericOp, bool)
// (signature from the generated index below; interior lines elided).
// Rewrites a linalg.generic whose body is a single unary payload op into the
// corresponding named linalg op, or — when `emitCategoryOp` is set — into a
// linalg.elementwise category op carrying an ElementwiseKind attribute.
90static FailureOr<LinalgOp>
92 bool emitCategoryOp) {
// Named-op specialization requires all indexing maps to be identities.
93 bool hasNonIdentityMaps =
94 !llvm::all_of(genericOp.getIndexingMapsArray(),
95 [](
AffineMap map) { return map.isIdentity(); });
98 if (hasNonIdentityMaps && !emitCategoryOp)
101 "non-identity indexing maps prevent specialization to named op");
// Helper: replace the generic with either the named op (decltype(namedOp))
// or an ElementwiseOp of the given kind, depending on elided control flow
// around lines 107-112 (presumably keyed on emitCategoryOp — TODO confirm).
106 auto replaceUnaryOp = [&](
auto namedOp, ElementwiseKind kind) -> LinalgOp {
109 newOp =
decltype(namedOp)::create(
110 rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
113 newOp = ElementwiseOp::create(
114 rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
115 genericOp.getDpsInits(),
116 ElementwiseKindAttr::get(rewriter.
getContext(), kind),
117 genericOp.getIndexingMaps());
// Dispatch on the single payload operation at the front of the body block.
124 Operation *op = &genericOp.getBody()->front();
126 if (isa<math::ExpOp>(op))
127 return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
128 if (isa<math::LogOp>(op))
129 return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
130 if (isa<math::AbsFOp>(op))
131 return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
132 if (isa<math::CeilOp>(op))
133 return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
134 if (isa<math::FloorOp>(op))
135 return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
136 if (isa<arith::NegFOp>(op))
137 return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
// Special case: `1.0 / x` (DivF with an exact 1.0 constant LHS) is a
// reciprocal, not a generic binary division.
138 if (
auto divOp = dyn_cast<arith::DivFOp>(op)) {
139 if (
auto constOp = dyn_cast_if_present<arith::ConstantOp>(
140 divOp.getLhs().getDefiningOp()))
141 if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
142 return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
144 if (isa<math::RoundOp>(op))
145 return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
146 if (isa<math::SqrtOp>(op))
147 return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
148 if (isa<math::RsqrtOp>(op))
149 return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
// Special case: `x * x` (MulF with identical operands) is a unary square.
150 if (
auto mulOp = dyn_cast<arith::MulFOp>(op);
151 mulOp && mulOp.getLhs() == mulOp.getRhs())
152 return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
153 if (isa<math::TanhOp>(op))
154 return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
155 if (isa<math::ErfOp>(op))
156 return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
187enum class IndexMatchResult {
// Checks the two map results at positions [rowDimIdx] and [rowDimIdx + 1]
// against the expected loop-dimension positions:
//   - Match      : results are (expectedPosOfRowDim, expectedPosOfColDim)
//   - Transposed : results are the same pair but swapped
//   - Mismatch   : anything else (including, per the elided lines 207-211,
//                  results that are not plain AffineDimExprs — TODO confirm).
201static IndexMatchResult matchOperandMap(
AffineMap map,
unsigned rowDimIdx,
202 unsigned expectedPosOfRowDim,
203 unsigned expectedPosOfColDim) {
205 auto exprOfRowDim = map.
getResults()[rowDimIdx];
206 auto exprOfColDim = map.
getResults()[rowDimIdx + 1];
211 return IndexMatchResult::Mismatch;
213 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
214 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
216 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
217 return IndexMatchResult::Match;
219 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
220 return IndexMatchResult::Transposed;
222 return IndexMatchResult::Mismatch;
// Fragment of replaceWithMatmulVariant<NamedOpTy> (interior lines elided).
// Replaces the generic contraction `op` with the named op NamedOpTy
// (Matmul/BatchMatmul/Contract per the callers below), attaching a "cast"
// TypeFn attribute for unsigned casts and an "indexing_maps" array attribute.
231template <
typename NamedOpTy>
232static LinalgOp replaceWithMatmulVariant(
RewriterBase &rewriter, GenericOp op,
233 std::optional<TypeFn> castTy,
// Only a cast_unsigned TypeFn is materialized as an explicit attribute;
// cast_signed is presumably the named-op default (TODO confirm).
238 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
240 "cast", TypeFnAttr::get(rewriter.
getContext(), *castTy));
241 attributes.push_back(castAttr);
248 return AffineMapAttr::get(map);
251 "indexing_maps", rewriter.
getArrayAttr(indexingMapsAttrVal));
252 attributes.push_back(indexingMapsAttr);
// Final replacement uses exactly two DPS inputs (A, B); the init/result
// handling falls in elided lines 256+.
255 op,
ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
// Fragment (interior lines elided). Walks the payload's cast ops to decide
// which TypeFn ("cast" semantics) a specialized matmul-like named op should
// carry: cast_unsigned vs cast_signed, defaulting to cast_signed.
// Return type is std::optional<TypeFn>; the elided lines 295-302 presumably
// return std::nullopt for the failure cases flagged below — TODO confirm.
264static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
265 bool foundCastForMatmulOutput =
false;
267 genericOp.getBody()->walk([&](CastOpInterface castOp) {
// A cast whose forward slice never reaches a multiply feeds the matmul
// *output* rather than an input operand — flagged as unsupported.
276 if (!llvm::any_of(forwardSlice, [](
Operation *op) {
279 return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
281 foundCastForMatmulOutput =
true;
// Classify the input-operand cast as unsigned- or signed-flavored.
286 if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
287 castTyFns.push_back(TypeFn::cast_unsigned);
288 else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
289 castTyFns.push_back(TypeFn::cast_signed);
294 if (foundCastForMatmulOutput)
// All input casts must agree on a single flavor; mixed flavors cannot be
// represented by one "cast" TypeFn attribute.
297 if (!castTyFns.empty()) {
301 if (!llvm::all_equal(castTyFns))
303 return castTyFns.front();
// No casts at all: signed is the conventional default for named matmul ops.
307 return TypeFn::cast_signed;
// Fragment of specializeLinalgContractions (interior lines elided).
// Matches a generic contraction (mul+add payload over 2 inputs / 1 init) and
// rewrites it to linalg.contract (category mode) or to Matmul/BatchMatmul
// when the indexing maps fit the canonical (batch, m, n, k) structure.
311static FailureOr<LinalgOp> specializeLinalgContractions(
RewriterBase &rewriter,
313 bool emitCategoryOp) {
// Contraction shape precondition: exactly two inputs and one init.
314 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
318 auto mapRange = genericOp.getIndexingMapsArray();
319 if (llvm::any_of(mapRange,
// Accepted payload pairs: (mulf,addf), (muli,addi), (complex.mul,complex.add).
328 return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
329 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
330 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
336 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
339 genericOp,
"contains invalid cast ops for the named matmul op");
// Category-op path: keep the original maps and emit linalg.contract.
343 return replaceWithMatmulVariant<ContractOp>(
344 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
// Named-op path: requires exactly one m, n and k dimension (plus batches).
371 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
375 auto indexingMaps = genericOp.getIndexingMapsArray();
376 if (llvm::any_of(indexingMaps, [&dims](
AffineMap m) {
378 dims.batch.size() + 2 ;
382 auto numOfBatchDims = dims.batch.size();
383 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
386 if (numOfBatchDims) {
// Batch dims must be the leading loop dims, in order, in every map.
390 if (llvm::any_of(indexingMaps, [numOfBatchDims](
AffineMap m) {
391 for (
unsigned i = 0; i < numOfBatchDims; ++i) {
394 cast<AffineDimExpr>(expr).getPosition() != i)
// Classify each operand map: A over (m,k), B over (k,n), C over (m,n).
403 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
405 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
407 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
409 if (llvm::is_contained({a,
b, c}, IndexMatchResult::Mismatch))
// Rebuild canonical maps with loop order (batch..., m, n, k), swapping
// row/col for operands classified as Transposed.
413 auto *ctx = genericOp.getContext();
414 unsigned numLoopDims = numOfBatchDims + 3;
415 unsigned mIdx = numOfBatchDims;
416 unsigned nIdx = mIdx + 1;
417 unsigned kIdx = mIdx + 2;
420 auto makeMap = [&](IndexMatchResult match,
unsigned rowIdx,
unsigned colIdx) {
422 for (
unsigned i = 0; i < numOfBatchDims; ++i)
423 tensorDims.push_back(i);
424 if (match == IndexMatchResult::Transposed)
425 llvm::append_values(tensorDims, colIdx, rowIdx);
427 llvm::append_values(tensorDims, rowIdx, colIdx);
431 auto mapA = makeMap(a, mIdx, kIdx);
432 auto mapB = makeMap(
b, kIdx, nIdx);
433 auto mapC = makeMap(c, mIdx, nIdx);
// Emit the batch variant iff batch dims are present.
438 if (numOfBatchDims) {
439 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
442 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
// Fragment of specializeToConvOp<ConvOpTy> (interior lines elided). Replaces
// a matched convolution-like generic with the named conv op ConvOpTy, built
// from strides/dilations attributes; Conv1D/2D/3D take no such attributes,
// hence the compile-time branch.
448template <
typename ConvOpTy>
449static FailureOr<LinalgOp>
450specializeToConvOp(
RewriterBase &rewriter, GenericOp genericOp,
// Plain ConvND ops carry no strides/dilations; elided lines handle them
// separately from the attribute-carrying variants (TODO confirm).
459 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
460 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
461 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
468 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
// Fragment of specializeLinalgConvolutions (lines 481-536, presumably the
// per-conv-op macro invocations, are elided). Tries each supported named
// conv op in turn via CONV_OP_SPECIALIZER: if the generic matches that conv
// flavor, rewrite it with the matched dilations/strides.
474static FailureOr<LinalgOp> specializeLinalgConvolutions(
RewriterBase &rewriter,
475 GenericOp genericOp) {
476#define CONV_OP_SPECIALIZER(ConvOpTy) \
477 if (std::optional<DilationsAndStrides> convParams = \
478 matchConvolutionOpOfType<ConvOpTy>(genericOp)) \
479 return specializeToConvOp<ConvOpTy>( \
480 rewriter, genericOp, convParams->dilations, convParams->strides); \
537#undef CONV_OP_SPECIALIZER
// Fragment, apparently the dispatch body of specializeGenericOp (per the
// generated index below) — most control-flow lines are elided. Visible order
// of attempted specializations: contraction → copy → fill → broadcast →
// transpose → single binary elementwise → convolution, then failure.
557 return specializeLinalgContractions(rewriter, genericOp,
566 genericOp,
"no matching category op specialization");
// Copy: single input forwarded to the single init.
571 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
// Fill: broadcast a scalar fill value into the init.
579 genericOp, *fillValue, genericOp.getDpsInits()[0]);
// Broadcast: isaBroadcastOpInterface-style match yielding added dimensions.
584 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
586 if (equivalentToBroadcast) {
587 auto dims = *equivalentToBroadcast;
589 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
// Transpose: match yields the permutation vector.
595 std::optional<SmallVector<int64_t>> equivalentToTranspose =
597 if (equivalentToTranspose) {
598 auto permutation = *equivalentToTranspose;
600 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
// Single binary elementwise payload: add/sub/mul/div (f32 flavors visible;
// the REPLACE_BINARY_OP expansions fall in the elided lines).
608 Operation *op = &genericOp.getBody()->front();
609 if (isa<arith::AddFOp>(op)) {
613 if (isa<arith::SubFOp>(op)) {
617 if (isa<arith::MulFOp>(op)) {
621 if (isa<arith::DivFOp>(op)) {
629 return specializeLinalgConvolutions(rewriter, genericOp);
632 "no matching named op specialization");
// Pass wrapper over the tablegen-generated base (GEN_PASS_DEF above); the
// actual work happens in runOnOperation(), defined out-of-line below.
636struct LinalgSpecializeGenericOpsPass
637 :
public impl::LinalgSpecializeGenericOpsPassBase<
638 LinalgSpecializeGenericOpsPass> {
// Inherit the generated constructors (including the options constructor).
640 using impl::LinalgSpecializeGenericOpsPassBase<
641 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
642 void runOnOperation()
override;
646void LinalgSpecializeGenericOpsPass::runOnOperation() {
static llvm::ManagedStatic< PassManagerOptions > options
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
static FailureOr< LinalgOp > specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp, bool emitCategoryOp)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
OpOperand & getOpOperand(unsigned idx)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
static WalkResult advance()
static WalkResult interrupt()
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possibl...
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
llvm::SetVector< T, Vector, Set, N > SetVector
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all transitive uses of op).