MLIR
20.0.0git
|
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
Go to the source code of this file.
Macros | |
#define | DEBUG_TYPE "linalg-transforms" |
#define | DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
#define | DBGSNL() (llvm::dbgs() << "\n") |
#define | LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") |
#define | DOWNSCALE(trans) |
#define | DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b> |
#define | DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) |
#define | GET_OP_CLASSES |
Functions | |
template<typename PatternTy , typename... Args> | |
static FailureOr< LinalgOp > | tryApply (Operation *operation, Args &&...args) |
Attempts to apply the pattern specified as template argument to the given operation. More... | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs) |
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value. More... | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle) |
static DiagnosedSilenceableFailure | reifyMixedParamAndHandleResults (TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified) |
When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred. More... | |
template<typename Range > | |
static LogicalResult | applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn) |
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More... | |
static Operation * | replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes) |
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More... | |
static std::tuple< SmallVector< Operation * >, Operation * > | tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
Find the first "extract" user of producerOp and tile it right before its use. More... | |
static SmallVector< Operation * > | tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. More... | |
static Operation * | cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
static void | printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type) |
static ParseResult | parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType) |
template<typename RelayoutOpTy > | |
bool | isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer) |
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More... | |
static void | printContinuousTileSizeTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type) |
static ParseResult | parseContinuousTileSizeTypes (OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType) |
static SmallVector< OpFoldResult > | normalizeUpperBounds (RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps) |
Given lbs, ubs, and steps of loops, return (for each loop) the normalized upper bound. More... | |
static SmallVector< Value > | denormalizeIndVar (RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps) |
When a loop is normalized, the uses of the induction variable within the loop need to be replaced with original_lb + old_iv * original_step . More... | |
static scf::ForallOp | normalizeForallLoopOp (RewriterBase &rewriter, scf::ForallOp loop) |
Given an scf.forall loop, return a loop op with the loop bounds normalized. More... | |
template<typename OpTy > | |
DiagnosedSilenceableFailure | doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state) |
#define DBGS | ( | ) | (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
Definition at line 55 of file LinalgTransformOps.cpp.
#define DBGSNL | ( | ) | (llvm::dbgs() << "\n") |
Definition at line 56 of file LinalgTransformOps.cpp.
#define DEBUG_TYPE "linalg-transforms" |
Definition at line 54 of file LinalgTransformOps.cpp.
#define DOWNSCALE | ( | trans | ) |
#define DOWNSCALE_CALL | ( | a, | |
b | |||
) | DownscaleSizeOneWindowed2DConvolution<a, b> |
#define DOWNSCALE_NORMAL | ( | a, | |
b | |||
) | DOWNSCALE(DOWNSCALE_CALL(a, b)) |
#define GET_OP_CLASSES |
Definition at line 3936 of file LinalgTransformOps.cpp.
#define LDBG | ( | X | ) | LLVM_DEBUG(DBGS() << (X) << "\n") |
Definition at line 57 of file LinalgTransformOps.cpp.
|
static |
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops.
Definition at line 517 of file LinalgTransformOps.cpp.
References mlir::Operation::emitError(), mlir::RewriterBase::eraseOp(), mlir::Operation::getOpResult(), mlir::RewriterBase::replaceAllUsesWith(), mlir::transform::TransformResults::set(), and mlir::OpBuilder::setInsertionPoint().
|
static |
Definition at line 902 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::Operation::getOpResults(), mlir::detail::IROperandBase::getOwner(), mlir::Operation::isProperAncestor(), mlir::RewriterBase::modifyOpInPlace(), and mlir::OpBuilder::setInsertionPoint().
|
static |
When a loop is normalized, the uses of the induction variable within the loop need to be replaced with `original_lb + old_iv * original_step`.
Definition at line 3207 of file LinalgTransformOps.cpp.
References mlir::bindDims(), mlir::bindSymbols(), mlir::Builder::getContext(), mlir::getValueOrCreateConstantIndexOp(), and mlir::affine::makeComposedFoldedAffineApply().
Referenced by normalizeForallLoopOp().
DiagnosedSilenceableFailure doit | ( | RewriterBase & | rewriter, |
OpTy | target, | ||
transform::ApplyToEachResultList & | results, | ||
transform::TransformState & | state | ||
) |
Definition at line 3734 of file LinalgTransformOps.cpp.
Referenced by llvm::cast_convert_val< T, ::mlir::Dialect *, ::mlir::Dialect * >::doit(), and tileLinalgOpImpl().
bool isValidPackingPermutation | ( | RelayoutOpTy | op, |
ArrayRef< int64_t > | permutation, | ||
OuterOrInnerPerm | outerOrInnerPerm = OuterOrInnerPerm::Outer |
||
) |
Return true if `permutation` is a valid permutation of the `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
This is the case when the `permutation` rank matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` is empty, to allow a simpler polymorphic implementation.
Definition at line 1622 of file LinalgTransformOps.cpp.
References mlir::isPermutationVector().
|
static |
Given an `scf.forall` loop, return a loop op with the loop bounds normalized.
TODO: Replace this with a general utility to normalize `scf.forall`. At the time of writing, this wasn't done because adding it to the scf dialect would disallow the use of `affine.apply` operations due to cyclic dependencies. To avoid churn in lit tests with the change this was added with, defer that to a follow-up.
Definition at line 3234 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::create(), denormalizeIndVar(), mlir::Builder::getIndexAttr(), mlir::isConstantIntValue(), mlir::RewriterBase::mergeBlocks(), normalizeUpperBounds(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPointToStart().
|
static |
Given `lbs`, `ubs`, and `steps` of loops, return (for each loop) the normalized upper bound.
Definition at line 3190 of file LinalgTransformOps.cpp.
References mlir::bindSymbols(), mlir::Builder::getContext(), and mlir::affine::makeComposedFoldedAffineApply().
Referenced by normalizeForallLoopOp().
|
static |
Definition at line 2832 of file LinalgTransformOps.cpp.
References mlir::AsmParser::emitError(), mlir::AsmParser::getCurrentLocation(), and mlir::AsmParser::parseType().
|
static |
Definition at line 1348 of file LinalgTransformOps.cpp.
References mlir::AsmParser::emitError(), mlir::AsmParser::getCurrentLocation(), and mlir::AsmParser::parseType().
|
static |
Definition at line 2826 of file LinalgTransformOps.cpp.
References mlir::OpAsmPrinter::printFunctionalType().
|
static |
Definition at line 1342 of file LinalgTransformOps.cpp.
References mlir::OpAsmPrinter::printFunctionalType().
|
static |
When possible, converts each `OpFoldResult` in `mixedResults` to an integer if the value can be statically inferred.
If a result is a `Value`, then it must be either a `ParamType` or a handle to a constant-like op.
Definition at line 179 of file LinalgTransformOps.cpp.
|
static |
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op.
Definition at line 635 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::create(), mlir::DominanceInfo::dominates(), mlir::detail::enumerate(), mlir::RewriterBase::eraseBlock(), mlir::Builder::getIndexAttr(), mlir::Operation::getLoc(), mlir::Operation::getResult(), mlir::Value::getUsers(), mlir::Operation::isAncestor(), mlir::Operation::isProperAncestor(), mlir::RewriterBase::replaceAllUsesWith(), mlir::RewriterBase::replaceUsesWithIf(), mlir::OpBuilder::setInsertionPoint(), and mlir::TilingResult::tiledValues.
Referenced by tileAndFuseFirstExtractUse().
|
static |
Find the first "extract" user of `producerOp` and tile it right before its use.
The tiled op is fused under the `containingOp`. Return this fused op on success or nullptr if anything fails. If the tiled op has uses that are dominated by `containingOp`, return a new `containingOp` with results of the fused op appended to results of the `containingOp`, or nullptr if there are no dominated uses.
Definition at line 718 of file LinalgTransformOps.cpp.
References DBGS, diag(), mlir::Operation::getLoc(), replaceForAllWithNewSignature(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPoint().
|
static |
First, find the first "scf::ForallOp" user of `producerOp` and ensure it is exactly the `containingOp`, otherwise bail.
Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the `containingOp`. Return this fused op on success or nullptr if anything fails.
Definition at line 800 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::modifyOpInPlace(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), and mlir::Operation::setOperand().
|
static |
Attempts to apply the pattern specified as template argument to the given operation.
The pattern is expected to have a `returningMatchAndRewrite` function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.
Definition at line 65 of file LinalgTransformOps.cpp.
References mlir::Operation::getContext().
|
static |
Assuming that `ofr` is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value.
Definition at line 93 of file LinalgTransformOps.cpp.
|
static |
Definition at line 145 of file LinalgTransformOps.cpp.