23 #include "llvm/ADT/ArrayRef.h"
31 for (
auto resType : resultTypes) {
32 VectorType vecType = dyn_cast<VectorType>(resType);
34 if (!vecType || vecType.getElementType().isIndex())
37 if (vecType.getRank() == 0)
39 unsigned trailingVecDimBitWidth =
40 vecType.getShape().back() * vecType.getElementTypeBitWidth();
41 if (trailingVecDimBitWidth >= targetBitWidth)
55 targetVectorBitWidth(targetVectBitWidth) {}
57 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
61 getTypeConverter()->convertType<VectorType>(constOp.getType());
63 if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
66 "Cannot linearize a constant scalable vector that's not a splat");
72 loc,
"Can't flatten since targetBitWidth <= OpSize");
73 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
77 dstElementsAttr = dstElementsAttr.reshape(resType);
84 unsigned targetVectorBitWidth;
87 struct LinearizeVectorizable final
92 LinearizeVectorizable(
97 targetVectorBitWidth(targetVectBitWidth) {}
103 op->
getLoc(),
"Can't flatten since targetBitWidth <= OpSize");
109 rewriter.
replaceOp(op, (*newOp)->getResults());
114 unsigned targetVectorBitWidth;
128 struct LinearizeVectorExtractStridedSlice final
131 LinearizeVectorExtractStridedSlice(
136 targetVectorBitWidth(targetVectBitWidth) {}
139 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
141 Type dstType = getTypeConverter()->convertType(extractOp.getType());
142 assert(!(extractOp.getVector().getType().isScalable() ||
143 cast<VectorType>(dstType).isScalable()) &&
144 "scalable vectors are not supported.");
147 extractOp,
"Can't flatten since targetBitWidth <= OpSize");
149 ArrayAttr offsets = extractOp.getOffsets();
150 ArrayAttr sizes = extractOp.getSizes();
151 ArrayAttr strides = extractOp.getStrides();
154 extractOp,
"Strided slice with stride != 1 is not supported.");
155 Value srcVector = adaptor.getVector();
164 int64_t extractGranularitySize = 1;
165 int64_t nD = extractOp.getSourceVectorType().getRank();
166 int64_t kD = (int64_t)offsets.size();
169 extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
173 int64_t nExtractedSlices = 1;
175 nExtractedSlices *= cast<IntegerAttr>(size).getInt();
179 for (
int i = kD - 2; i >= 0; --i) {
180 sourceStrides[i] = sourceStrides[i + 1] *
181 extractOp.getSourceVectorType().getShape()[i + 1];
186 extractGranularitySize);
190 for (
int i = kD - 2; i >= 0; --i) {
191 extractedStrides[i] =
192 extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
197 for (int64_t i = 0; i < nExtractedSlices; ++i) {
201 for (int64_t
j = 0;
j < kD; ++
j) {
202 multiDimIndex[
j] = (index / extractedStrides[
j]);
203 index -= multiDimIndex[
j] * extractedStrides[
j];
207 int64_t linearizedIndex = 0;
208 for (int64_t
j = 0;
j < kD; ++
j) {
210 (cast<IntegerAttr>(offsets[
j]).getInt() + multiDimIndex[
j]) *
215 for (int64_t
j = 0;
j < extractGranularitySize; ++
j) {
216 indices[i * extractGranularitySize +
j] = linearizedIndex +
j;
221 extractOp, dstType, srcVector, srcVector,
227 unsigned targetVectorBitWidth;
241 struct LinearizeVectorShuffle final
244 LinearizeVectorShuffle(
249 targetVectorBitWidth(targetVectBitWidth) {}
252 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
254 Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
255 assert(!(shuffleOp.getV1VectorType().isScalable() ||
256 shuffleOp.getV2VectorType().isScalable() ||
257 cast<VectorType>(dstType).isScalable()) &&
258 "scalable vectors are not supported.");
261 shuffleOp,
"Can't flatten since targetBitWidth <= OpSize");
263 Value vec1 = adaptor.getV1();
264 Value vec2 = adaptor.getV2();
265 int shuffleSliceLen = 1;
266 int rank = shuffleOp.getV1().
getType().getRank();
274 for (
unsigned i = 1; i < shape.size(); ++i) {
275 shuffleSliceLen *= shape[i];
283 ArrayAttr mask = shuffleOp.getMask();
284 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
286 for (
auto [i, value] :
289 int64_t v = value.getZExtValue();
290 std::iota(indices.begin() + shuffleSliceLen * i,
291 indices.begin() + shuffleSliceLen * (i + 1),
292 shuffleSliceLen * v);
301 unsigned targetVectorBitWidth;
313 struct LinearizeVectorExtract final
316 LinearizeVectorExtract(
321 targetVectorBitWidth(targetVectBitWidth) {}
323 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
325 Type dstTy = getTypeConverter()->convertType(extractOp.getType());
326 assert(!(extractOp.getVector().getType().isScalable() ||
327 cast<VectorType>(dstTy).isScalable()) &&
328 "scalable vectors are not supported.");
331 extractOp,
"Can't flatten since targetBitWidth <= OpSize");
334 if (extractOp.hasDynamicPosition())
336 "dynamic position is not supported.");
339 int64_t size = extractOp.getVector().getType().getNumElements();
342 int64_t linearizedOffset = 0;
346 linearizedOffset += offsets[i] * size;
350 std::iota(indices.begin(), indices.end(), linearizedOffset);
352 extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
359 unsigned targetVectorBitWidth;
367 typeConverter.
addConversion([](VectorType type) -> std::optional<Type> {
377 if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
378 !isa<VectorType>(type))
381 return builder.
create<vector::ShapeCastOp>(loc, type, inputs.front());
387 [=](
Operation *op) -> std::optional<bool> {
388 if ((isa<arith::ConstantOp>(op) ||
397 patterns.
add<LinearizeConstant, LinearizeVectorizable>(
398 typeConverter, patterns.
getContext(), targetBitWidth);
405 [=](vector::ShuffleOp shuffleOp) ->
bool {
407 ? (typeConverter.
isLegal(shuffleOp) &&
408 cast<mlir::VectorType>(shuffleOp.getResult().getType())
412 patterns.
add<LinearizeVectorShuffle, LinearizeVectorExtract,
413 LinearizeVectorExtractStridedSlice>(
414 typeConverter, patterns.
getContext(), targetBitWidth);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth)
Attributes are known-constant values of operations.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for linearizing ND (N >= 2) vector operations to 1D vector shuffle operations.
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This trait tags Elementwise operations that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.