MLIR
18.0.0git
|
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
Go to the source code of this file.
Macros | |
#define | DEBUG_TYPE "linalg-transforms" |
#define | DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
#define | DBGSNL() (llvm::dbgs() << "\n") |
#define | LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
#define | DOWNSCALE(trans) |
#define | DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b> |
#define | DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) |
#define | GET_OP_CLASSES |
Functions | |
template<typename PatternTy , typename... Args> | |
static FailureOr< LinalgOp > | tryApply (Operation *operation, Args &&...args) |
Attempts to apply the pattern specified as template argument to the given operation. More... | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs) |
Assuming that ofr is an index attr or a transform dialect handle mapped to exactly one op with one index result, return that value. More... | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle) |
template<typename Range > | |
static LogicalResult | applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn) |
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More... | |
static Operation * | replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes) |
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More... | |
static std::tuple< SmallVector< Operation * >, Operation * > | tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
Find the first "extract" user of producerOp and tile it right before its use. More... | |
static SmallVector< Operation * > | tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. More... | |
static Operation * | cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
static void | printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type) |
static ParseResult | parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType) |
template<typename RelayoutOpTy > | |
bool | isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer) |
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More... | |
ParseResult | parseOptionalInterchange (OpAsmParser &parser, OperationState &result) |
void | printOptionalInterchange (OpAsmPrinter &p, ArrayRef< int64_t > interchangeVals) |
template<typename OpTy > | |
DiagnosedSilenceableFailure | doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state) |
#define DBGS | ( | ) | (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
Definition at line 54 of file LinalgTransformOps.cpp.
#define DBGSNL | ( | ) | (llvm::dbgs() << "\n") |
Definition at line 55 of file LinalgTransformOps.cpp.
#define DEBUG_TYPE "linalg-transforms" |
Definition at line 53 of file LinalgTransformOps.cpp.
#define DOWNSCALE | ( | trans | ) |
#define DOWNSCALE_CALL | ( | a, | |
b | |||
) | DownscaleSizeOneWindowed2DConvolution<a, b> |
#define DOWNSCALE_NORMAL | ( | a, | |
b | |||
) | DOWNSCALE(DOWNSCALE_CALL(a, b)) |
#define GET_OP_CLASSES |
Definition at line 3349 of file LinalgTransformOps.cpp.
#define LDBG | ( | X | ) | LLVM_DEBUG(DBGS() << X << "\n") |
Definition at line 56 of file LinalgTransformOps.cpp.
|
static |
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops.
Definition at line 420 of file LinalgTransformOps.cpp.
References mlir::Operation::emitError(), mlir::RewriterBase::eraseOp(), mlir::failed(), mlir::failure(), mlir::Operation::getOpResult(), mlir::RewriterBase::replaceAllUsesWith(), mlir::transform::TransformResults::set(), mlir::OpBuilder::setInsertionPoint(), and mlir::success().
|
static |
Definition at line 828 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::Operation::getOpResults(), mlir::detail::IROperandBase::getOwner(), mlir::Operation::isProperAncestor(), mlir::OpBuilder::setInsertionPoint(), and mlir::RewriterBase::updateRootInPlace().
DiagnosedSilenceableFailure doit | ( | RewriterBase & | rewriter, |
OpTy | target, | ||
transform::ApplyToEachResultList & | results, | ||
transform::TransformState & | state | ||
) |
Definition at line 3220 of file LinalgTransformOps.cpp.
Referenced by llvm::cast_convert_val< T, ::mlir::Dialect *, ::mlir::Dialect * >::doit().
bool isValidPackingPermutation | ( | RelayoutOpTy | op, |
ArrayRef< int64_t > | permutation, | ||
OuterOrInnerPerm | outerOrInnerPerm = OuterOrInnerPerm::Outer |
||
) |
Return true if `permutation` is a valid permutation of the `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (case OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` op.
This is the case when the `permutation` rank matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` is empty, to allow a simpler polymorphic implementation.
Definition at line 1507 of file LinalgTransformOps.cpp.
|
static |
Definition at line 1233 of file LinalgTransformOps.cpp.
References mlir::AsmParser::emitError(), mlir::failed(), mlir::failure(), mlir::AsmParser::getCurrentLocation(), mlir::AsmParser::parseType(), and mlir::success().
ParseResult parseOptionalInterchange | ( | OpAsmParser & | parser, |
OperationState & | result | ||
) |
Definition at line 2659 of file LinalgTransformOps.cpp.
References mlir::OperationState::addAttribute(), mlir::failed(), mlir::failure(), mlir::OperationState::name, mlir::detail::DenseArrayAttrImpl< T >::parse(), mlir::AsmParser::parseEqual(), mlir::AsmParser::parseOptionalKeyword(), and mlir::success().
|
static |
Definition at line 1227 of file LinalgTransformOps.cpp.
void printOptionalInterchange | ( | OpAsmPrinter & | p, |
ArrayRef< int64_t > | interchangeVals | ||
) |
Definition at line 2671 of file LinalgTransformOps.cpp.
|
static |
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op.
Definition at line 561 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::create(), mlir::DominanceInfo::dominates(), mlir::RewriterBase::eraseBlock(), mlir::Operation::getResult(), mlir::Value::getUsers(), mlir::Operation::isAncestor(), mlir::RewriterBase::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by tileAndFuseFirstExtractUse().
|
static |
Find the first "extract" user of `producerOp` and tile it right before its use.
The tiled op is fused under the `containingOp`. Return this fused op on success or nullptr if anything fails. If the tiled op has uses that are dominated by `containingOp`, return a new `containingOp` with results of the fused op appended to results of the `containingOp`, or nullptr if there are no dominated uses.
Definition at line 644 of file LinalgTransformOps.cpp.
References DBGS, diag(), mlir::failed(), mlir::Operation::getLoc(), replaceForAllWithNewSignature(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPoint().
|
static |
First, find the first "scf::ForallOp" user of `producerOp` and ensure it is exactly the `containingOp`, otherwise bail.
Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the `containingOp`. Return this fused op on success or nullptr if anything fails.
Definition at line 726 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::failed(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), mlir::Operation::setOperand(), mlir::succeeded(), and mlir::RewriterBase::updateRootInPlace().
|
static |
Attempts to apply the pattern specified as template argument to the given operation.
The pattern is expected to have a `returningMatchAndRewrite` function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.
Definition at line 64 of file LinalgTransformOps.cpp.
|
static |
Assuming that `ofr` is an index attr or a transform dialect handle mapped to exactly one op with one index result, return that value.
Definition at line 91 of file LinalgTransformOps.cpp.
|
static |
Definition at line 131 of file LinalgTransformOps.cpp.