23 #include "llvm/ADT/ArrayRef.h"
31 for (
auto resType : resultTypes) {
32 VectorType vecType = dyn_cast<VectorType>(resType);
34 if (!vecType || vecType.getElementType().isIndex())
37 if (vecType.getRank() == 0)
39 unsigned trailingVecDimBitWidth =
40 vecType.getShape().back() * vecType.getElementTypeBitWidth();
41 if (trailingVecDimBitWidth >= targetBitWidth)
48 VectorType vecType = dyn_cast<VectorType>(t);
50 if (!vecType || vecType.getElementType().isIndex())
53 if (vecType.getRank() == 0)
55 unsigned trailingVecDimBitWidth =
56 vecType.getShape().back() * vecType.getElementTypeBitWidth();
57 return trailingVecDimBitWidth <= targetBitWidth;
60 static FailureOr<Attribute>
63 if (
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
64 if (resType.isScalable() && !isa<SplatElementsAttr>(value))
67 "Cannot linearize a constant scalable vector that's not a splat");
69 return dstElementsAttr.reshape(resType);
72 if (
auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
79 struct LinearizeConstantLike final
83 LinearizeConstantLike(
88 targetVectorBitWidth(targetVectBitWidth) {}
105 loc,
"Can't flatten since targetBitWidth <= OpSize");
112 FailureOr<Attribute> newValue =
114 if (failed(newValue))
117 FailureOr<Operation *> convertResult =
119 if (failed(convertResult))
123 newOp->
setAttr(attrName, *newValue);
129 unsigned targetVectorBitWidth;
132 struct LinearizeVectorizable final
137 LinearizeVectorizable(
142 targetVectorBitWidth(targetVectBitWidth) {}
148 op->
getLoc(),
"Can't flatten since targetBitWidth <= OpSize");
149 FailureOr<Operation *> newOp =
154 rewriter.
replaceOp(op, (*newOp)->getResults());
159 unsigned targetVectorBitWidth;
173 struct LinearizeVectorExtractStridedSlice final
176 LinearizeVectorExtractStridedSlice(
181 targetVectorBitWidth(targetVectBitWidth) {}
184 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
187 getTypeConverter()->convertType<VectorType>(extractOp.getType());
188 assert(dstType &&
"vector type destination expected.");
189 if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
191 "scalable vectors are not supported.");
194 extractOp,
"Can't flatten since targetBitWidth <= OpSize");
196 ArrayAttr offsets = extractOp.getOffsets();
197 ArrayAttr sizes = extractOp.getSizes();
198 ArrayAttr strides = extractOp.getStrides();
201 extractOp,
"Strided slice with stride != 1 is not supported.");
202 Value srcVector = adaptor.getVector();
211 int64_t extractGranularitySize = 1;
212 int64_t nD = extractOp.getSourceVectorType().getRank();
213 int64_t kD = (int64_t)offsets.size();
216 extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
220 int64_t nExtractedSlices = 1;
222 nExtractedSlices *= cast<IntegerAttr>(size).getInt();
226 for (
int i = kD - 2; i >= 0; --i) {
227 sourceStrides[i] = sourceStrides[i + 1] *
228 extractOp.getSourceVectorType().getShape()[i + 1];
233 extractGranularitySize);
237 for (
int i = kD - 2; i >= 0; --i) {
238 extractedStrides[i] =
239 extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
244 for (int64_t i = 0; i < nExtractedSlices; ++i) {
248 for (int64_t
j = 0;
j < kD; ++
j) {
249 multiDimIndex[
j] = (index / extractedStrides[
j]);
250 index -= multiDimIndex[
j] * extractedStrides[
j];
254 int64_t linearizedIndex = 0;
255 for (int64_t
j = 0;
j < kD; ++
j) {
257 (cast<IntegerAttr>(offsets[
j]).getInt() + multiDimIndex[
j]) *
262 for (int64_t
j = 0;
j < extractGranularitySize; ++
j) {
263 indices[i * extractGranularitySize +
j] = linearizedIndex +
j;
268 extractOp, dstType, srcVector, srcVector, indices);
273 unsigned targetVectorBitWidth;
287 struct LinearizeVectorShuffle final
290 LinearizeVectorShuffle(
295 targetVectorBitWidth(targetVectBitWidth) {}
298 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
301 getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
302 assert(dstType &&
"vector type destination expected.");
305 assert(!(shuffleOp.getV1VectorType().isScalable() ||
306 shuffleOp.getV2VectorType().isScalable() ||
307 dstType.isScalable()) &&
308 "scalable vectors are not supported.");
311 shuffleOp,
"Can't flatten since targetBitWidth <= OpSize");
313 Value vec1 = adaptor.getV1();
314 Value vec2 = adaptor.getV2();
315 int shuffleSliceLen = 1;
316 int rank = shuffleOp.getV1().
getType().getRank();
324 for (
unsigned i = 1; i < shape.size(); ++i) {
325 shuffleSliceLen *= shape[i];
334 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
337 std::iota(indices.begin() + shuffleSliceLen * i,
338 indices.begin() + shuffleSliceLen * (i + 1),
339 shuffleSliceLen * value);
348 unsigned targetVectorBitWidth;
360 struct LinearizeVectorExtract final
363 LinearizeVectorExtract(
368 targetVectorBitWidth(targetVectBitWidth) {}
370 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
372 Type dstTy = getTypeConverter()->convertType(extractOp.getType());
375 "expected n-D vector type.");
377 if (extractOp.getVector().getType().isScalable() ||
378 cast<VectorType>(dstTy).isScalable())
380 "scalable vectors are not supported.");
383 extractOp,
"Can't flatten since targetBitWidth <= OpSize");
386 if (extractOp.hasDynamicPosition())
388 "dynamic position is not supported.");
391 int64_t size = extractOp.getVector().getType().getNumElements();
394 int64_t linearizedOffset = 0;
398 linearizedOffset += offsets[i] * size;
402 std::iota(indices.begin(), indices.end(), linearizedOffset);
404 extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
410 unsigned targetVectorBitWidth;
423 struct LinearizeVectorInsert final
426 LinearizeVectorInsert(
431 targetVectorBitWidth(targetVectBitWidth) {}
433 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
435 VectorType dstTy = getTypeConverter()->convertType<VectorType>(
436 insertOp.getDestVectorType());
437 assert(dstTy &&
"vector type destination expected.");
438 if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
440 "scalable vectors are not supported.");
443 targetVectorBitWidth))
445 insertOp,
"Can't flatten since targetBitWidth < OpSize");
448 if (insertOp.hasDynamicPosition())
450 "dynamic position is not supported.");
451 auto srcTy = insertOp.getSourceType();
452 auto srcAsVec = dyn_cast<VectorType>(srcTy);
453 uint64_t srcSize = 0;
455 srcSize = srcAsVec.getNumElements();
458 "scalars are not supported.");
461 auto dstShape = insertOp.getDestVectorType().getShape();
462 const auto dstSize = insertOp.getDestVectorType().getNumElements();
463 auto dstSizeForOffsets = dstSize;
466 int64_t linearizedOffset = 0;
467 auto offsetsNd = insertOp.getStaticPosition();
469 dstSizeForOffsets /= dstShape[dim];
470 linearizedOffset += offset * dstSizeForOffsets;
474 auto origValsUntil = indices.begin();
475 std::advance(origValsUntil, linearizedOffset);
476 std::iota(indices.begin(), origValsUntil,
478 auto newValsUntil = origValsUntil;
479 std::advance(newValsUntil, srcSize);
480 std::iota(origValsUntil, newValsUntil,
482 std::iota(newValsUntil, indices.end(),
483 linearizedOffset + srcSize);
487 insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
493 unsigned targetVectorBitWidth;
504 struct LinearizeVectorBitCast final
507 LinearizeVectorBitCast(
512 targetVectorBitWidth(targetVectBitWidth) {}
514 matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
517 auto resType = getTypeConverter()->convertType(castOp.getType());
523 loc,
"Can't flatten since targetBitWidth <= OpSize");
526 adaptor.getSource());
527 return mlir::success();
531 unsigned targetVectorBitWidth;
540 typeConverter.
addConversion([](VectorType type) -> std::optional<Type> {
550 if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
551 !isa<VectorType>(type))
554 return builder.
create<vector::ShapeCastOp>(loc, type, inputs.front());
559 [=](
Operation *op) -> std::optional<bool> {
560 if ((isa<vector::BitCastOp>(op) ||
570 patterns.add<LinearizeConstantLike, LinearizeVectorizable,
571 LinearizeVectorBitCast>(typeConverter,
patterns.getContext(),
579 [=](vector::ShuffleOp shuffleOp) ->
bool {
581 ? (typeConverter.
isLegal(shuffleOp) &&
582 cast<mlir::VectorType>(shuffleOp.getResult().getType())
586 patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
587 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
588 typeConverter,
patterns.getContext(), targetBitWidth);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth)
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth)
static FailureOr< Attribute > linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value)
Attributes are known-constant values of operations.
StringAttr getStringAttr(const Twine &bytes)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for linearizing ND (N >= 2) vector operations to 1D vector shuffle operations.
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This trait tags Elementwise operatons that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.