22 #include "llvm/ADT/ArrayRef.h"
30 for (
auto resType : resultTypes) {
31 VectorType vecType = dyn_cast<VectorType>(resType);
33 if (!vecType || vecType.getElementType().isIndex())
36 if (vecType.getRank() == 0)
38 unsigned trailingVecDimBitWidth =
39 vecType.getShape().back() * vecType.getElementTypeBitWidth();
40 if (trailingVecDimBitWidth >= targetBitWidth)
47 VectorType vecType = dyn_cast<VectorType>(t);
49 if (!vecType || vecType.getElementType().isIndex())
52 if (vecType.getRank() == 0)
54 unsigned trailingVecDimBitWidth =
55 vecType.getShape().back() * vecType.getElementTypeBitWidth();
56 return trailingVecDimBitWidth <= targetBitWidth;
67 targetVectorBitWidth(targetVectBitWidth) {}
69 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
73 getTypeConverter()->convertType<VectorType>(constOp.getType());
75 if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
78 "Cannot linearize a constant scalable vector that's not a splat");
84 loc,
"Can't flatten since targetBitWidth <= OpSize");
85 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
89 dstElementsAttr = dstElementsAttr.reshape(resType);
96 unsigned targetVectorBitWidth;
99 struct LinearizeVectorizable final
104 LinearizeVectorizable(
109 targetVectorBitWidth(targetVectBitWidth) {}
115 op->
getLoc(),
"Can't flatten since targetBitWidth <= OpSize");
116 FailureOr<Operation *> newOp =
121 rewriter.
replaceOp(op, (*newOp)->getResults());
126 unsigned targetVectorBitWidth;
140 struct LinearizeVectorExtractStridedSlice final
143 LinearizeVectorExtractStridedSlice(
148 targetVectorBitWidth(targetVectBitWidth) {}
151 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
154 getTypeConverter()->convertType<VectorType>(extractOp.getType());
155 assert(dstType &&
"vector type destination expected.");
156 if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
158 "scalable vectors are not supported.");
161 extractOp,
"Can't flatten since targetBitWidth <= OpSize");
163 ArrayAttr offsets = extractOp.getOffsets();
164 ArrayAttr sizes = extractOp.getSizes();
165 ArrayAttr strides = extractOp.getStrides();
168 extractOp,
"Strided slice with stride != 1 is not supported.");
169 Value srcVector = adaptor.getVector();
178 int64_t extractGranularitySize = 1;
179 int64_t nD = extractOp.getSourceVectorType().getRank();
180 int64_t kD = (int64_t)offsets.size();
183 extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
187 int64_t nExtractedSlices = 1;
189 nExtractedSlices *= cast<IntegerAttr>(size).getInt();
193 for (
int i = kD - 2; i >= 0; --i) {
194 sourceStrides[i] = sourceStrides[i + 1] *
195 extractOp.getSourceVectorType().getShape()[i + 1];
200 extractGranularitySize);
204 for (
int i = kD - 2; i >= 0; --i) {
205 extractedStrides[i] =
206 extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
211 for (int64_t i = 0; i < nExtractedSlices; ++i) {
215 for (int64_t
j = 0;
j < kD; ++
j) {
216 multiDimIndex[
j] = (index / extractedStrides[
j]);
217 index -= multiDimIndex[
j] * extractedStrides[
j];
221 int64_t linearizedIndex = 0;
222 for (int64_t
j = 0;
j < kD; ++
j) {
224 (cast<IntegerAttr>(offsets[
j]).getInt() + multiDimIndex[
j]) *
229 for (int64_t
j = 0;
j < extractGranularitySize; ++
j) {
230 indices[i * extractGranularitySize +
j] = linearizedIndex +
j;
235 extractOp, dstType, srcVector, srcVector, indices);
240 unsigned targetVectorBitWidth;
254 struct LinearizeVectorShuffle final
257 LinearizeVectorShuffle(
262 targetVectorBitWidth(targetVectBitWidth) {}
265 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
268 getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
269 assert(dstType &&
"vector type destination expected.");
272 assert(!(shuffleOp.getV1VectorType().isScalable() ||
273 shuffleOp.getV2VectorType().isScalable() ||
274 dstType.isScalable()) &&
275 "scalable vectors are not supported.");
278 shuffleOp,
"Can't flatten since targetBitWidth <= OpSize");
280 Value vec1 = adaptor.getV1();
281 Value vec2 = adaptor.getV2();
282 int shuffleSliceLen = 1;
283 int rank = shuffleOp.getV1().
getType().getRank();
291 for (
unsigned i = 1; i < shape.size(); ++i) {
292 shuffleSliceLen *= shape[i];
301 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
304 std::iota(indices.begin() + shuffleSliceLen * i,
305 indices.begin() + shuffleSliceLen * (i + 1),
306 shuffleSliceLen * value);
315 unsigned targetVectorBitWidth;
327 struct LinearizeVectorExtract final
330 LinearizeVectorExtract(
335 targetVectorBitWidth(targetVectBitWidth) {}
337 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
339 Type dstTy = getTypeConverter()->convertType(extractOp.getType());
342 "expected n-D vector type.");
344 if (extractOp.getVector().getType().isScalable() ||
345 cast<VectorType>(dstTy).isScalable())
347 "scalable vectors are not supported.");
350 extractOp,
"Can't flatten since targetBitWidth <= OpSize");
353 if (extractOp.hasDynamicPosition())
355 "dynamic position is not supported.");
358 int64_t size = extractOp.getVector().getType().getNumElements();
361 int64_t linearizedOffset = 0;
365 linearizedOffset += offsets[i] * size;
369 std::iota(indices.begin(), indices.end(), linearizedOffset);
371 extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
377 unsigned targetVectorBitWidth;
390 struct LinearizeVectorInsert final
393 LinearizeVectorInsert(
398 targetVectorBitWidth(targetVectBitWidth) {}
400 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
402 VectorType dstTy = getTypeConverter()->convertType<VectorType>(
403 insertOp.getDestVectorType());
404 assert(dstTy &&
"vector type destination expected.");
405 if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
407 "scalable vectors are not supported.");
410 targetVectorBitWidth))
412 insertOp,
"Can't flatten since targetBitWidth < OpSize");
415 if (insertOp.hasDynamicPosition())
417 "dynamic position is not supported.");
418 auto srcTy = insertOp.getSourceType();
419 auto srcAsVec = dyn_cast<VectorType>(srcTy);
420 uint64_t srcSize = 0;
422 srcSize = srcAsVec.getNumElements();
425 "scalars are not supported.");
428 auto dstShape = insertOp.getDestVectorType().getShape();
429 const auto dstSize = insertOp.getDestVectorType().getNumElements();
430 auto dstSizeForOffsets = dstSize;
433 int64_t linearizedOffset = 0;
434 auto offsetsNd = insertOp.getStaticPosition();
436 dstSizeForOffsets /= dstShape[dim];
437 linearizedOffset += offset * dstSizeForOffsets;
441 auto origValsUntil = indices.begin();
442 std::advance(origValsUntil, linearizedOffset);
443 std::iota(indices.begin(), origValsUntil,
445 auto newValsUntil = origValsUntil;
446 std::advance(newValsUntil, srcSize);
447 std::iota(origValsUntil, newValsUntil,
449 std::iota(newValsUntil, indices.end(),
450 linearizedOffset + srcSize);
454 insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
460 unsigned targetVectorBitWidth;
468 typeConverter.
addConversion([](VectorType type) -> std::optional<Type> {
478 if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
479 !isa<VectorType>(type))
482 return builder.
create<vector::ShapeCastOp>(loc, type, inputs.front());
488 [=](
Operation *op) -> std::optional<bool> {
489 if ((isa<arith::ConstantOp>(op) ||
498 patterns.
add<LinearizeConstant, LinearizeVectorizable>(
499 typeConverter, patterns.
getContext(), targetBitWidth);
506 [=](vector::ShuffleOp shuffleOp) ->
bool {
508 ? (typeConverter.
isLegal(shuffleOp) &&
509 cast<mlir::VectorType>(shuffleOp.getResult().getType())
513 patterns.
add<LinearizeVectorShuffle, LinearizeVectorExtract,
514 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
515 typeConverter, patterns.
getContext(), targetBitWidth);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth)
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth)
Attributes are known-constant values of operations.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for linearizing ND (N >= 2) vector operations to 1D vector shuffle operations.
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait tags Elementwise operatons that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.