22 #include "llvm/ADT/ArrayRef.h"
29 static FailureOr<Attribute>
33 if (
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
34 if (resType.isScalable() && !isa<SplatElementsAttr>(value))
37 "Cannot linearize a constant scalable vector that's not a splat");
39 return dstElementsAttr.reshape(resType);
42 if (
auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
50 struct LinearizeConstantLike final
67 assert(resType &&
"expected 1-D vector type");
74 FailureOr<Attribute> newValue =
79 FailureOr<Operation *> convertResult =
81 if (failed(convertResult))
85 newOp->
setAttr(attrName, *newValue);
91 struct LinearizeVectorizable final
102 FailureOr<Operation *> newOp =
107 rewriter.
replaceOp(op, (*newOp)->getResults());
112 template <
typename TOp>
113 static bool stridesAllOne(TOp op) {
115 std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116 std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117 "expected vector.extract_strided_slice or vector.insert_strided_slice");
118 ArrayAttr strides = op.getStrides();
123 static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
127 ints.reserve(attrs.size());
128 for (
auto attr : attrs) {
129 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130 ints.push_back(intAttr.getInt());
158 assert((large.size() >= small.size()) &&
159 "rank of 'large' cannot be lower than rank of 'small'");
160 assert((large.size() >= offsets.size()) &&
161 "rank of 'large' cannot be lower than the number of offsets");
162 unsigned delta = large.size() - small.size();
163 unsigned nOffsets = offsets.size();
164 auto getSmall = [&](int64_t i) -> int64_t {
165 return i >= delta ? small[i - delta] : 1;
167 auto getOffset = [&](int64_t i) -> int64_t {
168 return i < nOffsets ? offsets[i] : 0;
176 for (
int i = large.size() - 1; i >= 0; --i) {
177 int64_t currentSize = indices.size();
178 int64_t smallSize = getSmall(i);
179 int64_t nextSize = currentSize * smallSize;
181 int64_t *base = nextIndices.begin();
182 int64_t offset = getOffset(i) * stride;
183 for (
int j = 0;
j < smallSize; ++
j) {
184 for (
int k = 0; k < currentSize; ++k) {
185 base[k] = indices[k] + offset;
191 indices = std::move(nextIndices);
215 struct LinearizeVectorExtractStridedSlice final
218 LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter,
224 matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
228 VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
229 extractStridedSliceOp.getType());
230 assert(flatOutputType &&
"vector type expected");
234 if (!stridesAllOne(extractStridedSliceOp)) {
236 extractStridedSliceOp,
237 "extract_strided_slice with strides != 1 not supported");
240 FailureOr<SmallVector<int64_t>> offsets =
241 intsFromArrayAttr(extractStridedSliceOp.getOffsets());
242 if (failed(offsets)) {
244 "failed to get integer offsets");
248 extractStridedSliceOp.getSourceVectorType().getShape();
253 outputShape, inputShape, offsets.value());
255 Value srcVector = adaptor.getVector();
257 extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
286 struct LinearizeVectorInsertStridedSlice final
289 LinearizeVectorInsertStridedSlice(
const TypeConverter &typeConverter,
295 matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
301 if (!stridesAllOne(insertStridedSliceOp)) {
303 insertStridedSliceOp,
304 "insert_strided_slice with strides != 1 not supported");
307 VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
310 VectorType outputType = insertStridedSliceOp.getType();
312 int64_t nOutputElements = outputType.getNumElements();
314 FailureOr<SmallVector<int64_t>> offsets =
315 intsFromArrayAttr(insertStridedSliceOp.getOffsets());
316 if (failed(offsets)) {
318 "failed to get integer offsets");
321 inputShape, outputShape, offsets.value());
324 std::iota(indices.begin(), indices.end(), 0);
326 indices[sliceIndex] = index + nOutputElements;
329 Value flatToStore = adaptor.getValueToStore();
330 Value flatDest = adaptor.getDest();
333 flatToStore, indices);
349 struct LinearizeVectorShuffle final
357 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
360 getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
361 assert(dstType &&
"vector type destination expected.");
363 Value vec1 = adaptor.getV1();
364 Value vec2 = adaptor.getV2();
365 int shuffleSliceLen = 1;
366 int rank = shuffleOp.getV1().
getType().getRank();
374 for (
unsigned i = 1; i < shape.size(); ++i) {
375 shuffleSliceLen *= shape[i];
384 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
387 std::iota(indices.begin() + shuffleSliceLen * i,
388 indices.begin() + shuffleSliceLen * (i + 1),
389 shuffleSliceLen * value);
424 struct LinearizeVectorExtract final
431 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
433 Type dstTy = getTypeConverter()->convertType(extractOp.getType());
434 assert(dstTy &&
"expected 1-D vector type");
437 if (extractOp.hasDynamicPosition())
439 "dynamic position is not supported.");
442 int64_t size = extractOp.getVector().getType().getNumElements();
445 int64_t linearizedOffset = 0;
449 linearizedOffset += offsets[i] * size;
452 Value srcVector = adaptor.getVector();
453 if (!isa<VectorType>(extractOp.getType())) {
456 extractOp.getLoc(), srcVector, linearizedOffset);
464 std::iota(indices.begin(), indices.end(), linearizedOffset);
502 struct LinearizeVectorInsert final
509 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
511 VectorType dstTy = getTypeConverter()->convertType<VectorType>(
512 insertOp.getDestVectorType());
513 assert(dstTy &&
"vector type destination expected.");
516 if (insertOp.hasDynamicPosition())
518 "dynamic position is not supported.");
519 auto srcTy = insertOp.getValueToStoreType();
520 auto srcAsVec = dyn_cast<VectorType>(srcTy);
521 uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;
523 auto dstShape = insertOp.getDestVectorType().getShape();
524 const auto dstSize = insertOp.getDestVectorType().getNumElements();
525 auto dstSizeForOffsets = dstSize;
528 int64_t linearizedOffset = 0;
529 auto offsetsNd = insertOp.getStaticPosition();
531 dstSizeForOffsets /= dstShape[dim];
532 linearizedOffset += offset * dstSizeForOffsets;
536 Value valueToStore = adaptor.getValueToStore();
538 if (!isa<VectorType>(valueToStore.
getType())) {
541 loc, valueToStore, adaptor.getDest(), linearizedOffset);
548 auto *origValsUntil = indices.begin();
549 std::advance(origValsUntil, linearizedOffset);
552 std::iota(indices.begin(), origValsUntil, 0);
553 auto *newValsUntil = origValsUntil;
554 std::advance(newValsUntil, srcSize);
556 std::iota(origValsUntil, newValsUntil, dstSize);
558 std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);
561 loc, dstTy, adaptor.getDest(), valueToStore, indices);
576 struct LinearizeVectorBitCast final
583 matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
585 auto resType = getTypeConverter()->convertType(castOp.getType());
586 assert(resType &&
"expected 1-D vector type");
588 adaptor.getSource());
589 return mlir::success();
599 struct LinearizeVectorSplat final
608 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
610 auto dstTy = getTypeConverter()->convertType(splatOp.getType());
630 struct LinearizeVectorCreateMask final
634 LinearizeVectorCreateMask(
const TypeConverter &typeConverter,
639 matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
641 Location loc = createMaskOp.getLoc();
642 VectorType srcTy = createMaskOp.getType();
643 auto srcShape = srcTy.getShape();
644 if (srcShape.size() != 2)
646 "only 2D mask is supported.");
648 if (srcShape[0] != 1)
650 createMaskOp,
"only unit outer dimension is supported.");
652 auto dstTy = getTypeConverter()->convertType(srcTy);
660 auto firstOperand = adaptor.getOperands().front();
662 auto isNonZero = rewriter.
createOrFold<mlir::arith::CmpIOp>(
663 loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
664 auto isNonZeroIndex = rewriter.
createOrFold<mlir::arith::IndexCastOp>(
666 auto secondOperand = adaptor.getOperands().back();
667 auto maskSize = rewriter.
createOrFold<mlir::arith::AndIOp>(
668 loc, rewriter.
getIndexType(), isNonZeroIndex, secondOperand);
671 rewriter.
create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
672 rewriter.
replaceOp(createMaskOp, newMask);
693 matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
695 VectorType vecTy = loadOp.getType();
699 auto shape = vecTy.getShape();
700 auto scalableDims = vecTy.getScalableDims();
703 if (!llvm::all_of(shape.drop_back(1), [](
auto d) { return d == 1; }))
705 "only vector<1x1x...xN> supported");
707 if (llvm::any_of(scalableDims.drop_back(1), [](
bool s) { return s; }))
709 "only innermost dim may be scalable");
711 auto linearTy = typeConverter->
convertType<VectorType>(vecTy);
713 auto newLoad = rewriter.
create<vector::LoadOp>(
714 loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
715 rewriter.
replaceOp(loadOp, newLoad.getResult());
731 struct LinearizeVectorStore final
739 matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
741 VectorType vecTy = storeOp.getValueToStore().getType();
745 auto shape = vecTy.getShape();
746 auto scalableDims = vecTy.getScalableDims();
749 if (!llvm::all_of(shape.drop_back(1), [](
auto d) { return d == 1; }))
751 "only vector<1x1x...xN> supported");
753 if (llvm::any_of(scalableDims.drop_back(1), [](
bool s) { return s; }))
755 "only innermost dim may be scalable");
758 storeOp, adaptor.getValueToStore(), adaptor.getBase(),
759 adaptor.getIndices());
772 StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
774 bool supported = (opDialect == vectorDialect) ||
784 .Case<vector::ShapeCastOp>([&](
auto) {
return false; })
794 .Case<vector::ExtractStridedSliceOp>(
795 [&](vector::ExtractStridedSliceOp extractOp) {
796 return !extractOp.getType().isScalable();
798 .Case<vector::InsertStridedSliceOp>(
799 [&](vector::InsertStridedSliceOp insertOp) {
800 return !insertOp.getType().isScalable();
802 .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
803 return !insertOp.getType().isScalable();
805 .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
806 return !extractOp.getSourceVectorType().isScalable();
808 .Default([&](
auto) {
return true; });
811 void mlir::vector::populateForVectorLinearize(
TypeConverter &typeConverter,
814 auto convertType = [](
Type type) -> std::optional<Type> {
815 VectorType vectorType = dyn_cast<VectorType>(type);
819 VectorType linearizedType =
821 vectorType.getElementType(), vectorType.isScalable());
822 return linearizedType;
828 if (inputs.size() != 1)
831 Value value = inputs.front();
832 if (!isa<VectorType>(type) || !isa<VectorType>(value.
getType()))
835 return builder.
create<vector::ShapeCastOp>(loc, type, value);
841 [=](
Operation *op) -> std::optional<bool> {
846 return typeConverter.
isLegal(op);
850 void mlir::vector::populateVectorLinearizeBasePatterns(
854 .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
855 LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
856 LinearizeVectorStore>(typeConverter,
patterns.getContext());
859 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
862 patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
863 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
864 LinearizeVectorInsertStridedSlice>(typeConverter,
static FailureOr< Attribute > linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value)
static bool isLinearizable(Operation *op)
This method defines the set of operations that are linearizable, and hence that are considered illega...
Attributes are known-constant values of operations.
StringAttr getStringAttr(const Twine &bytes)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
Include the generated interface declarations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This trait tags Elementwise operations that can be systematically vectorized.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.