15#include "llvm/ADT/TypeSwitch.h"
25 using OpRewritePattern<GenerateOp>::OpRewritePattern;
27 LogicalResult matchAndRewrite(GenerateOp generateOp,
28 PatternRewriter &rewriter)
const override {
30 llvm::cast<RankedTensorType>(generateOp.getResult().getType());
31 if (!tensorType.hasStaticShape())
34 cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
38 Operation *constantOp =
43 tensorType, generateOp->getLoc());
62 assert(inputShape.size() == outputStrides.size());
66 for (
int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
73 auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
75 srcLinearIndex = quotient;
79 dstLinearIndex += outputStrides[dim] * remainder;
82 return dstLinearIndex;
85template <
typename ElemType,
typename AttrType>
93 auto oldShape = input.
getType().getShape();
97 llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98 [](std::tuple<int64_t, int64_t, int64_t> pack) {
99 auto [old, low, high] =
pack;
100 return old + low + high;
120 for (
auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121 auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122 values[outputIndex + startingOffset] = inputValue;
126 auto newType = input.
getType().clone(newShape);
134 return constantOp ? constantOp->
getResult(0) :
nullptr;
139 PadOpToConstant(MLIRContext *context,
const ControlFoldFn &controlFn,
140 PatternBenefit benefit = 1)
141 : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}
143 LogicalResult matchAndRewrite(PadOp padTensorOp,
144 PatternRewriter &rewriter)
const override {
145 if (padTensorOp.getNofold())
147 padTensorOp,
"refusing to fold nofold pad operation");
150 RankedTensorType resultType = padTensorOp.getResult().getType();
152 DenseElementsAttr inputAttr =
nullptr;
156 Value paddingValue = padTensorOp.getConstantPaddingValue();
159 Attribute paddingAttr =
nullptr;
162 "unable to get constant value");
169 if (!lowPad || !highPad)
171 "unable to extract constant padding");
175 if (!controlFn(&padTensorOp.getSourceMutable()))
177 "not folding due to cost function");
179 Location loc = padTensorOp.getLoc();
183 llvm::TypeSwitch<Attribute, Value>(paddingAttr)
184 .Case([&](FloatAttr floatAttr) {
185 return constantFoldPadOp<llvm::APFloat>(
186 rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
188 .Case([&](IntegerAttr integerAttr) {
189 return constantFoldPadOp<llvm::APInt>(
190 rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
196 "tensor type not supported");
198 if (newOp.
getType() != resultType)
199 newOp = tensor::CastOp::create(rewriter, loc, resultType, newOp);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
FailureOr< iterator_range_impl< ElementIterator< T > > > tryGetValues() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
std::function< bool(OpOperand *)> ControlFoldFn
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...