MLIR  19.0.0git
VectorPattern.cpp
Go to the documentation of this file.
1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 using namespace mlir;
13 
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying LLVM representation of the n-D vector.
16 // Iterates on the LLVM array type until we hit a non-array type (which is
17 // expected to be an LLVM vector type; otherwise llvm1DVectorTy is not set).
20  const LLVMTypeConverter &converter) { // NOTE(review): listing lines 18-19 (start of this signature) are missing from this view; per the index below it is `extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)`.
21  assert(vectorType.getRank() > 1 && "expected >1D vector type");
22  NDVectorTypeInfo info;
23  info.llvmNDVectorTy = converter.convertType(vectorType);
25  info.llvmNDVectorTy = nullptr; // NOTE(review): listing line 24 (presumably the guard rejecting a failed or non-array conversion) is missing from this view; this early exit returns info with a null llvmNDVectorTy.
26  return info;
27  }
28  info.arraySizes.reserve(vectorType.getRank() - 1); // Capacity hint: one LLVM array level per outer (non-innermost) dimension is expected.
29  auto llvmTy = info.llvmNDVectorTy;
30  while (isa<LLVM::LLVMArrayType>(llvmTy)) { // Peel nested LLVM array types, recording each element count from outermost to innermost.
31  info.arraySizes.push_back(
32  cast<LLVM::LLVMArrayType>(llvmTy).getNumElements());
33  llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType();
34  }
35  if (!LLVM::isCompatibleVectorType(llvmTy)) // Innermost type is not an LLVM-compatible vector: return without assigning llvm1DVectorTy.
36  return info;
37  info.llvm1DVectorTy = llvmTy; // The innermost vector type is the per-position 1-D subvector type.
38  return info;
39 }
40 
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P), where
43 // P is the product of all the basis elements.
44 //
45 // Prerequisites:
46 // Basis is an array of positive integers (signed type inherited from
47 // vector shape type); a zero entry would cause a division by zero.
49  unsigned linearIndex) { // NOTE(review): listing line 48 (start of the signature, `getCoordinates(ArrayRef<int64_t> basis,` per the index below) and line 50 (presumably the declaration of `res`, matching the SmallVector<int64_t, 4> return type) are missing from this view.
51  res.reserve(basis.size());
52  for (unsigned basisElement : llvm::reverse(basis)) { // Walk the basis innermost (last) dimension first, peeling off one mixed-radix digit per dimension. NOTE(review): a zero entry divides by zero below.
53  res.push_back(linearIndex % basisElement);
54  linearIndex = linearIndex / basisElement;
55  }
56  if (linearIndex > 0) // A leftover quotient means linearIndex >= product of the basis: out of range.
57  return {};
58  std::reverse(res.begin(), res.end()); // Digits were produced innermost-first; restore outermost-first order.
59  return res;
60 }
61 
62 // Iterate over all linear indices, convert each one to coordinate space, and
63 // invoke `fun` with the coordinates of that position.
65  OpBuilder &builder, // NOTE(review): listing line 64 (start of the signature, `nDVectorIterate(const NDVectorTypeInfo &info,` per the index below) is missing from this view; `builder` is not used in the visible body.
66  function_ref<void(ArrayRef<int64_t>)> fun) {
67  unsigned ub = 1; // Number of positions = product of all array sizes. NOTE(review): accumulated in `unsigned`, so a very large product would wrap.
68  for (auto s : info.arraySizes)
69  ub *= s;
70  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71  auto coords = getCoordinates(info.arraySizes, linearIndex);
72  // getCoordinates signals an out-of-range index with an empty result; stop
    // iterating. Presumably also reached when info.arraySizes is empty, in
    // which case `fun` is never invoked.
73  if (coords.empty())
74  break;
75  assert(coords.size() == info.arraySizes.size());
76  fun(coords);
77  }
78 }
79 
81  Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, // NOTE(review): listing line 80 (return type and name, `LogicalResult handleMultidimensionalVectors(` per the index below) is missing from this view.
82  std::function<Value(Type, ValueRange)> createOperand,
83  ConversionPatternRewriter &rewriter) {
84  auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType()); // Expects result #0 to be a VectorType.
85  auto resultTypeInfo =
86  extractNDVectorTypeInfo(resultNDVectorType, typeConverter); // NOTE(review): the returned types are not checked for null before use below; consider returning failure() if extraction failed.
87  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
88  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
89  auto loc = op->getLoc();
90  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy); // Start from an undef aggregate and fill it one 1-D subvector at a time.
91  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
92  // For this unrolled `position`, extract the 1-D subvector at that
93  // position from every operand.
94  SmallVector<Value, 4> extractedOperands;
95  for (const auto &operand : llvm::enumerate(operands)) { // Only operand.value() is read; the enumerate index is unused.
96  extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
97  loc, operand.value(), position));
98  }
99  Value newVal = createOperand(result1DVectorTy, extractedOperands); // Build this position's 1-D result via the caller-supplied callback.
100  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position); // Insert it into the aggregate at the same position.
101  });
102  rewriter.replaceOp(op, desc);
103  return success();
104 }
105 
108  ValueRange operands, // NOTE(review): listing lines 106-107 (return type, name, and the `op`/`targetOp` parameters, `vectorOneToOneRewrite(Operation *op, StringRef targetOp,` per the index below) are missing from this view.
109  ArrayRef<NamedAttribute> targetAttrs,
110  const LLVMTypeConverter &typeConverter,
111  ConversionPatternRewriter &rewriter) {
112  assert(!operands.empty()); // operands[0] is read below.
113 
114  // Cannot convert ops if their operands are not of LLVM type.
115  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
116  return failure();
117 
118  auto llvmNDVectorTy = operands[0].getType(); // Only the first operand's type decides whether unrolling is needed.
119  if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) // Not an LLVM array: no unrolling needed, rewrite 1:1.
120  return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
121  rewriter);
122 
123  auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy, // Creates one `targetOp` per 1-D slice; `rewriter` is captured by reference and the callback is only invoked inside handleMultidimensionalVectors.
124  ValueRange operands) {
125  return rewriter
126  .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
127  llvm1DVectorTy, targetAttrs)
128  ->getResult(0);
129  };
130 
131  return handleMultidimensionalVectors(op, operands, typeConverter, callback,
132  rewriter);
133 }
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:335
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:880
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:862
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t, 4 > arraySizes
Definition: VectorPattern.h:27
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26