25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
29#define DEBUG_TYPE "linalg-specialization"
50 Block *body = genericOp.getBody();
57 "binary op uses just one block arg");
81 Block *body = genericOp.getBody();
83 for (
auto [i, v] : llvm::enumerate(op->
getOperands())) {
84 if (
auto blockArg = dyn_cast<BlockArgument>(v);
85 blockArg && blockArg.getOwner() == body)
123 bool emitCategoryOp) {
124 bool hasNonIdentityMaps =
125 !llvm::all_of(genericOp.getIndexingMapsArray(),
126 [](
AffineMap map) { return map.isIdentity(); });
129 if (hasNonIdentityMaps && !emitCategoryOp)
132 "non-identity indexing maps prevent specialization to named op");
135 bool isUnary = genericOp.getNumDpsInputs() == 1;
136 bool isBinary = genericOp.getNumDpsInputs() == 2;
139 Operation *op = &genericOp.getBody()->front();
143 int scalarOprIdx = -1;
150 auto replaceOp = [&](
auto namedOp, ElementwiseKind kind,
151 bool mayHoistScalarOperand =
true) -> LinalgOp {
153 if (hasSwappedOperands)
154 std::swap(inputs[0], inputs[1]);
157 if (!emitCategoryOp) {
158 using NamedOpTy =
decltype(namedOp);
159 if constexpr (!std::is_null_pointer_v<NamedOpTy>)
160 newOp = NamedOpTy::create(rewriter, genericOp.getLoc(), inputs,
161 genericOp.getDpsInits(),
164 llvm_unreachable(
"Missing named op type");
168 if (hasSwappedOperands)
169 std::swap(indexingMaps[0], indexingMaps[1]);
173 if (hasScalarOperand && mayHoistScalarOperand) {
175 inputs.insert(inputs.begin() + scalarOprIdx,
177 auto scalarBroadcastMap =
180 indexingMaps.insert(indexingMaps.begin() + scalarOprIdx,
183 newOp = ElementwiseOp::create(
184 rewriter, genericOp.getLoc(), inputs, genericOp.getDpsInits(),
185 ElementwiseKindAttr::get(rewriter.
getContext(), kind),
194 if (isa<math::ExpOp>(op))
195 return replaceOp(ExpOp{}, ElementwiseKind::exp);
196 if (isa<math::LogOp>(op))
197 return replaceOp(LogOp{}, ElementwiseKind::log);
198 if (isa<math::AbsFOp>(op))
199 return replaceOp(AbsOp{}, ElementwiseKind::abs);
200 if (isa<math::CeilOp>(op))
201 return replaceOp(CeilOp{}, ElementwiseKind::ceil);
202 if (isa<math::FloorOp>(op))
203 return replaceOp(FloorOp{}, ElementwiseKind::floor);
204 if (isa<arith::NegFOp>(op))
205 return replaceOp(NegFOp{}, ElementwiseKind::negf);
206 if (
auto divOp = dyn_cast<arith::DivFOp>(op)) {
207 if (
auto constOp = dyn_cast_if_present<arith::ConstantOp>(
208 divOp.getLhs().getDefiningOp()))
209 if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
210 return replaceOp(ReciprocalOp{}, ElementwiseKind::reciprocal,
213 if (isa<math::RoundOp>(op))
214 return replaceOp(RoundOp{}, ElementwiseKind::round);
215 if (isa<math::SqrtOp>(op))
216 return replaceOp(SqrtOp{}, ElementwiseKind::sqrt);
217 if (isa<math::RsqrtOp>(op))
218 return replaceOp(RsqrtOp{}, ElementwiseKind::rsqrt);
219 if (
auto mulOp = dyn_cast<arith::MulFOp>(op);
220 mulOp && mulOp.getLhs() == mulOp.getRhs())
221 return replaceOp(SquareOp{}, ElementwiseKind::square);
222 if (isa<math::TanhOp>(op))
223 return replaceOp(TanhOp{}, ElementwiseKind::tanh);
224 if (isa<math::ErfOp>(op))
225 return replaceOp(ErfOp{}, ElementwiseKind::erf);
229 if (emitCategoryOp) {
230 if (isa<math::SinOp>(op))
231 return replaceOp(
nullptr, ElementwiseKind::sin);
232 if (isa<math::CosOp>(op))
233 return replaceOp(
nullptr, ElementwiseKind::cos);
234 if (isa<math::TanOp>(op))
235 return replaceOp(
nullptr, ElementwiseKind::tan);
236 if (isa<math::AcosOp>(op))
237 return replaceOp(
nullptr, ElementwiseKind::acos);
238 if (isa<math::AcoshOp>(op))
239 return replaceOp(
nullptr, ElementwiseKind::acosh);
240 if (isa<math::AsinOp>(op))
241 return replaceOp(
nullptr, ElementwiseKind::asin);
242 if (isa<math::AsinhOp>(op))
243 return replaceOp(
nullptr, ElementwiseKind::asinh);
244 if (isa<math::AtanOp>(op))
245 return replaceOp(
nullptr, ElementwiseKind::atan);
246 if (isa<math::AtanhOp>(op))
247 return replaceOp(
nullptr, ElementwiseKind::atanh);
248 if (isa<math::Log10Op>(op))
249 return replaceOp(
nullptr, ElementwiseKind::log10);
250 if (isa<math::Log1pOp>(op))
251 return replaceOp(
nullptr, ElementwiseKind::log1p);
252 if (isa<math::Log2Op>(op))
253 return replaceOp(
nullptr, ElementwiseKind::log2);
260 if (!emitCategoryOp || !hasScalarOperand)
262 genericOp,
"unary elementwise operation cannot be specialized to "
263 "named or category op");
268 [](
Value v) { return v.getType().isInteger(1); });
270 if (isa<arith::AddIOp, arith::AddFOp, complex::AddOp>(op) ||
271 (allBool && isa<arith::OrIOp>(op)))
272 return replaceOp(AddOp{}, ElementwiseKind::add);
273 if (isa<arith::SubIOp, arith::SubFOp, complex::SubOp>(op))
274 return replaceOp(SubOp{}, ElementwiseKind::sub);
275 if (isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op) ||
276 (allBool && isa<arith::AndIOp>(op)))
277 return replaceOp(MulOp{}, ElementwiseKind::mul);
278 if (isa<arith::DivSIOp, arith::DivFOp, complex::DivOp>(op))
279 return replaceOp(DivOp{}, ElementwiseKind::div);
280 if (isa<arith::DivUIOp>(op))
281 return replaceOp(DivUnsignedOp{}, ElementwiseKind::div_unsigned);
282 if (isa<arith::MaxSIOp, arith::MaximumFOp>(op))
283 return replaceOp(MaxOp{}, ElementwiseKind::max_signed);
284 if (isa<arith::MinSIOp, arith::MinimumFOp>(op))
285 return replaceOp(MinOp{}, ElementwiseKind::min_signed);
286 if (emitCategoryOp) {
288 if (isa<arith::MaxUIOp>(op))
289 return replaceOp(
nullptr, ElementwiseKind::max_unsigned);
290 if (isa<arith::MinUIOp>(op))
291 return replaceOp(
nullptr, ElementwiseKind::min_unsigned);
293 if (isa<math::PowFOp>(op))
294 return replaceOp(PowFOp{}, ElementwiseKind::powf);
298 "elementwise operation cannot be specialized to named or category op");
327enum class IndexMatchResult {
341static IndexMatchResult matchOperandMap(
AffineMap map,
unsigned rowDimIdx,
342 unsigned expectedPosOfRowDim,
343 unsigned expectedPosOfColDim) {
345 auto exprOfRowDim = map.
getResults()[rowDimIdx];
346 auto exprOfColDim = map.
getResults()[rowDimIdx + 1];
351 return IndexMatchResult::Mismatch;
353 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
354 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
356 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
357 return IndexMatchResult::Match;
359 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
360 return IndexMatchResult::Transposed;
362 return IndexMatchResult::Mismatch;
371template <
typename NamedOpTy>
372static LinalgOp replaceWithMatmulVariant(
RewriterBase &rewriter, GenericOp op,
373 std::optional<TypeFn> castTy,
378 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
380 "cast", TypeFnAttr::get(rewriter.
getContext(), *castTy));
381 attributes.push_back(castAttr);
388 return AffineMapAttr::get(map);
391 "indexing_maps", rewriter.
getArrayAttr(indexingMapsAttrVal));
392 attributes.push_back(indexingMapsAttr);
395 op,
ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
404static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
405 bool foundCastForMatmulOutput =
false;
407 genericOp.getBody()->walk([&](CastOpInterface castOp) {
416 if (!llvm::any_of(forwardSlice, [](
Operation *op) {
419 return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
421 foundCastForMatmulOutput =
true;
426 if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
427 castTyFns.push_back(TypeFn::cast_unsigned);
428 else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
429 castTyFns.push_back(TypeFn::cast_signed);
434 if (foundCastForMatmulOutput)
437 if (!castTyFns.empty()) {
441 if (!llvm::all_equal(castTyFns))
443 return castTyFns.front();
447 return TypeFn::cast_signed;
450static FailureOr<LinalgOp> specializeLinalgMmt4D(
RewriterBase &rewriter,
452 std::optional<TypeFn> castTy,
455 auto indexingMaps = genericOp.getIndexingMapsArray();
456 if (llvm::any_of(indexingMaps, [](
AffineMap m) {
461 auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.
m[0], dims.
k[0]);
462 auto aInner = matchOperandMap(indexingMaps[0], 2, dims.
m[1], dims.
k[1]);
464 auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.
k[0], dims.
n[0]);
465 auto bInner = matchOperandMap(indexingMaps[1], 2, dims.
k[1], dims.
n[1]);
467 auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.
m[0], dims.
n[0]);
468 auto cInner = matchOperandMap(indexingMaps[2], 2, dims.
m[1], dims.
n[1]);
470 if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
472 if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
478 return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
483 if (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second))
485 if (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second))
487 if (isa<complex::MulOp>(first) && isa<complex::AddOp>(second))
489 if (isa<arith::AndIOp>(first) && isa<arith::OrIOp>(second) &&
497static FailureOr<LinalgOp> specializeLinalgContractions(
RewriterBase &rewriter,
499 bool emitCategoryOp) {
500 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
504 auto mapRange = genericOp.getIndexingMapsArray();
505 if (llvm::any_of(mapRange,
514 isSupportedContractionPair))
519 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
522 genericOp,
"contains invalid cast ops for the named matmul op");
526 return replaceWithMatmulVariant<ContractOp>(
527 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
554 if (dims.
m.size() == 2 && dims.
n.size() == 2 && dims.
k.size() == 2)
555 return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
556 if (dims.
m.size() != 1 || dims.
n.size() != 1 || dims.
k.size() != 1)
560 auto indexingMaps = genericOp.getIndexingMapsArray();
561 if (llvm::any_of(indexingMaps, [&dims](
AffineMap m) {
563 dims.
batch.size() + 2 ;
567 auto numOfBatchDims = dims.
batch.size();
568 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
571 if (numOfBatchDims) {
575 if (llvm::any_of(indexingMaps, [numOfBatchDims](
AffineMap m) {
576 for (
unsigned i = 0; i < numOfBatchDims; ++i) {
579 cast<AffineDimExpr>(expr).getPosition() != i)
588 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.
m[0], dims.
k[0]);
590 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.
k[0], dims.
n[0]);
592 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.
m[0], dims.
n[0]);
594 if (llvm::is_contained({a,
b, c}, IndexMatchResult::Mismatch))
598 auto *ctx = genericOp.getContext();
599 unsigned numLoopDims = numOfBatchDims + 3;
600 unsigned mIdx = numOfBatchDims;
601 unsigned nIdx = mIdx + 1;
602 unsigned kIdx = mIdx + 2;
605 auto makeMap = [&](IndexMatchResult match,
unsigned rowIdx,
unsigned colIdx) {
607 for (
unsigned i = 0; i < numOfBatchDims; ++i)
608 tensorDims.push_back(i);
609 if (match == IndexMatchResult::Transposed)
610 llvm::append_values(tensorDims, colIdx, rowIdx);
612 llvm::append_values(tensorDims, rowIdx, colIdx);
616 auto mapA = makeMap(a, mIdx, kIdx);
617 auto mapB = makeMap(
b, kIdx, nIdx);
618 auto mapC = makeMap(c, mIdx, nIdx);
623 if (numOfBatchDims) {
624 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
627 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
633template <
typename ConvOpTy>
634static FailureOr<LinalgOp>
635specializeToConvOp(
RewriterBase &rewriter, GenericOp genericOp,
644 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
645 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
646 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
653 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
659static FailureOr<LinalgOp> specializeLinalgConvolutions(
RewriterBase &rewriter,
660 GenericOp genericOp) {
661#define CONV_OP_SPECIALIZER(ConvOpTy) \
662 if (std::optional<DilationsAndStrides> convParams = \
663 matchConvolutionOpOfType<ConvOpTy>(genericOp)) \
664 return specializeToConvOp<ConvOpTy>( \
665 rewriter, genericOp, convParams->dilations, convParams->strides); \
722#undef CONV_OP_SPECIALIZER
743 return specializeLinalgContractions(rewriter, genericOp,
752 genericOp,
"no matching category op specialization");
757 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
765 genericOp, *fillValue, genericOp.getDpsInits()[0]);
770 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
772 if (equivalentToBroadcast) {
773 auto dims = *equivalentToBroadcast;
775 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
781 std::optional<SmallVector<int64_t>> equivalentToTranspose =
783 if (equivalentToTranspose) {
784 auto permutation = *equivalentToTranspose;
786 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
793 return specializeLinalgConvolutions(rewriter, genericOp);
796 "no matching named op specialization");
800struct LinalgSpecializeGenericOpsPass
802 LinalgSpecializeGenericOpsPass> {
805 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
806 void runOnOperation()
override;
810void LinalgSpecializeGenericOpsPass::runOnOperation() {
static llvm::ManagedStatic< PassManagerOptions > options
static bool findIndexOfScalarOperand(GenericOp genericOp, int &index)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
static FailureOr< LinalgOp > specializeLinalgElementwise(RewriterBase &rewriter, GenericOp genericOp, bool emitCategoryOp)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
OpOperand & getOpOperand(unsigned idx)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possibl...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
llvm::SetVector< T, Vector, Set, N > SetVector
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Positions of a Linalg op's loops that correspond to the different kinds of contraction dimensions.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k