15 #include "llvm/ADT/TypeSwitch.h"
27 LogicalResult matchAndRewrite(GenerateOp generateOp,
30 llvm::cast<RankedTensorType>(generateOp.getResult().getType());
31 if (!tensorType.hasStaticShape())
34 cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
43 tensorType, generateOp->getLoc());
61 int64_t srcLinearIndex) {
62 assert(inputShape.size() == outputStrides.size());
64 int64_t dstLinearIndex = 0;
66 for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
73 auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
75 srcLinearIndex = quotient;
79 dstLinearIndex += outputStrides[dim] * remainder;
82 return dstLinearIndex;
85 template <
typename ElemType,
typename AttrType>
90 if (failed(inputValues))
93 auto oldShape = input.
getType().getShape();
97 llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98 [](std::tuple<int64_t, int64_t, int64_t>
pack) {
99 auto [old, low, high] =
pack;
100 return old + low + high;
116 int64_t startingOffset =
linearize(padLow, outputStrides);
121 auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122 values[outputIndex + startingOffset] = inputValue;
126 auto newType = input.
getType().clone(newShape);
134 return constantOp ? constantOp->
getResult(0) :
nullptr;
143 LogicalResult matchAndRewrite(PadOp padTensorOp,
145 if (padTensorOp.getNofold())
147 padTensorOp,
"refusing to fold nofold pad operation");
150 RankedTensorType resultType = padTensorOp.getResult().
getType();
156 Value paddingValue = padTensorOp.getConstantPaddingValue();
162 "unable to get constant value");
169 if (!lowPad || !highPad)
171 "unable to extract constant padding");
175 if (!controlFn(&padTensorOp.getSourceMutable()))
177 "not folding due to cost function");
179 Location loc = padTensorOp.getLoc();
184 .Case([&](FloatAttr floatAttr) {
185 return constantFoldPadOp<llvm::APFloat>(
186 rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
188 .Case([&](IntegerAttr integerAttr) {
189 return constantFoldPadOp<llvm::APInt>(
190 rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
196 "tensor type not supported");
198 if (newOp.
getType() != resultType)
199 newOp = rewriter.
create<tensor::CastOp>(loc, resultType, newOp);
215 patterns.
add<PadOpToConstant>(patterns.
getContext(), controlFn);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
FailureOr< iterator_range_impl< ElementIterator< T > > > tryGetValues() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
std::function< bool(OpOperand *)> ControlFoldFn
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...