//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy),
                                   input, padValue, lowIndices, highIndices,
                                   /*nofold=*/false, loc, rewriter)
      .result();
}
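
// As an illustration (shapes assumed, not taken from this file): padding a
// tensor<1x14x14x3xf32> input with pad = {0, 0, 1, 1, 1, 1, 0, 0} emits a
// tensor.pad that produces tensor<1x16x16x3xf32>, filled with padAttr's value.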

static mlir::Value reifyConstantDim(Attribute attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.createOrFold<arith::IndexCastOp>(
      builder.getIndexType(), builder.create<arith::ConstantOp>(attr));
}

// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
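// For example (illustrative values, not from the source): IH = 14,
// pad_top = pad_bottom = 1, KH = 3, dilation_y = 1, stride_y = 2 gives
// H = ((14 + 1 + 1 - (1 * (3 - 1) + 1)) / 2) + 1 = (13 / 2) + 1 = 7.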
static mlir::Value
getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
                 Attribute padAfterAttr, Value kernelDim, Attribute strideAttr,
                 Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(initDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(initDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  // The formula ends in "+ 1", so add (not subtract) one to the quotient.
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayAttr padAttr, ArrayAttr strideAttr, ArrayAttr dilationAttr,
    int64_t weightHDim, int64_t weightWDim, OpBuilder &rewriter) {
  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  int64_t inputRank = inputTy.getRank();
  int64_t heightDim = 1;
  int64_t widthDim = 2;

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());
  for (int i = 0; i < inputRank; i++) {
    if (inputTy.isDynamicDim(i) && i != heightDim && i != widthDim)
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  // Dynamic input height
  if (inputTy.isDynamicDim(heightDim)) {
    Value initHDim =
        rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
    Value kernelHDim =
        rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
    // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
    dynDims[heightDim] = getConvOutputDim(
        loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], kernelHDim,
        strideAttr.getValue()[0], dilationAttr.getValue()[0], inputETy,
        rewriter);
  }

  // Dynamic input width
  if (inputTy.isDynamicDim(widthDim)) {
    Value initWDim =
        rewriter.create<tensor::DimOp>(loc, input, widthDim).getResult();
    Value kernelWDim =
        rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
    // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
    dynDims[widthDim] = getConvOutputDim(
        loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], kernelWDim,
        strideAttr.getValue()[1], dilationAttr.getValue()[1], inputETy,
        rewriter);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the Depthwise convolution op
// due to a shape mismatch
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}
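
// For example (an assumed rank): with outputRank = 4 the reassociation map is
// [[d0], [d1], [d2], [d3, d4]], collapsing the trailing (channel, multiplier)
// pair of the depthwise result into a single output channel dimension.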
159 
160 namespace {
161 
162 class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
163 public:
166  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
167  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
    bool isQuantized = op->hasAttr("quantization_info");

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
        /*weightHDim=*/1, /*weightWDim=*/2, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

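    // TOSA's pad attribute holds only the H and W padding; the resizes below
    // prepend {0, 0} for the batch dimension and append {0, 0} for the channel
    // dimension, forming [low, high] pairs for all four NHWC dimensions.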
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Transpose the kernel to match dimension ordering of the linalg
    // convolution operation.
    // TODO(suderman): See if this can be efficiently folded - check whether
    // the input is used anywhere else, if not fold the constant.
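    // (tosa.conv2d weights are laid out [OC, KH, KW, IC], while
    // linalg.conv_2d_nhwc_hwcf expects [KH, KW, IC, OC]; hence the
    // {1, 2, 3, 0} permutation below.)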
    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    auto weightPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
    Value weightPermValue =
        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());
    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                weightPermValue);

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    // Create maps for the bias broadcasting
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);

    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      auto kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::Conv2DNhwcHwcfQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    Value conv = rewriter
                     .create<linalg::Conv2DNhwcHwcfOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
                     ->getResult(0);

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddFOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);

    rewriter.replaceOp(op, result);
    return success();
  }
};
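
// As a rough sketch (op spelling varies by MLIR version), a float tosa.conv2d:
//   %0 = "tosa.conv2d"(%input, %weight, %bias)
//          {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]}
// lowers to approximately:
//   %perm = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
//   %t    = "tosa.transpose"(%weight, %perm)
//   %fill = linalg.fill(%zero, %init)
//   %conv = linalg.conv_2d_nhwc_hwcf ins(%input, %t) outs(%fill)
//   %out  = linalg.generic ...  // broadcast-add %bias to %conv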

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
        /*weightHDim=*/0, /*weightWDim=*/1, rewriter);

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());
    }

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);
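    // (linalg.depthwise_conv_2d_nhwc_hwcm yields a 5-D [N, H, W, C, M]
    // tensor, while TOSA expects [N, H, W, C * M]; the result is therefore
    // collapsed below via tensor.collapse_shape.)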

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, linalgConvTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();

    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();
    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
          ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.a_zp().getValue().getSExtValue()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.b_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);

    return success();
  }
};
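
// For example (types assumed for illustration): a non-quantized tosa.matmul on
// tensor<1x5x3xf32> and tensor<1x3x6xf32> becomes a linalg.batch_matmul
// accumulating into a zero-filled tensor<1x5x6xf32>.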

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();

    auto bias = op.bias();

    auto weight = op.weight();
    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    // Creating maps for the output of MatMul and the bias
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());

    // When quantized, the input element type is not the same as the output.
    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);
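    // (tosa.fully_connected weights are [OC, IC]; linalg.matmul needs the
    // [IC, OC] orientation, hence the {1, 0} transpose above.)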

    auto biasInitTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, filteredDims,
                                          outputTy.getShape(), outputETy)
            ->getResults();

    if (!op.quantization_info()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, zeroTensor)
                         ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.input_zp().getValue().getSExtValue()));
    auto weightZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.weight_zp().getValue().getSExtValue()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                zeroTensor)
            ->getResult(0);
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddIOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = inputTy.getElementType();

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    // Determine what the initial value needs to be for the max pool op.
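    // (The padding/init value is the identity of the max reduction: the most
    // negative finite float for floats, or the minimum signed value for
    // integers.)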
    Attribute initialAttr;
    if (resultETy.isF32())
      initialAttr = rewriter.getFloatAttr(
          resultETy,
          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
                              true));

    if (resultETy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{initTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledInitTensor, strideAttr, dilationAttr);
    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = op.getType().cast<ShapedType>().getElementType();

    Type accETy =
        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Attribute padAttr = rewriter.getZeroAttr(inElementTy);
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    Attribute initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, accTy.getShape(), accETy);

    Value filledInitTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolInitTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledInitTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());

    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultETy);

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericInitTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
          auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
          auto iH = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(1) - 1);
          auto iW = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(2) - 1);

          // Compute the indices from either end.
          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
          auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
          auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);

          // Determines what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
            if (pad == 0)
              return v;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);

            Value cmp = rewriter.create<arith::CmpIOp>(
                loc, arith::CmpIPredicate::slt, dx, zero);
            Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
            return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
          };
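          // For example (assumed values): with pad = 1 and x = 0, dx = -1, so
          // one kernel row or column falls outside the valid input and the
          // running coverage v shrinks by one.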

          // Compute the vertical component of coverage.
          auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
          auto kH1 = padFn(kH0, y0, pad[2]);
          auto kH2 = padFn(kH1, y1, pad[3]);
          auto kHCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kH2, one);
          auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);

          // Compute the horizontal component of coverage.
          auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
          auto kW1 = padFn(kW0, x0, pad[4]);
          auto kW2 = padFn(kW1, x1, pad[5]);
          auto kWCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kW2, one);
          auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);

          // Compute the total number of elements and normalize.
          Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
          auto countI = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(), count);

          // Divide by the number of summed values. For floats this is just
          // a div; for quantized values, input normalization must also
          // be applied.
          Value poolVal = args[0];
          if (accETy.isa<FloatType>()) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, quantizationInfo.input_zp());
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute the multiplier and shift values for the quantization
            // normalization. Preferably we would want to compute more bits
            // however 32-bits should be enough for compute. Honestly we
            // should probably straight divide.
            int64_t numerator = ((1 << 30) + 1);
            int64_t shift = 30;
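            // For example (an assumed count): count = 9 gives multiplier =
            // (2^30 + 1) / 9 = 119304647, and apply_scale computes
            // (poolVal * multiplier) >> 30, i.e. roughly poolVal / 9.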

            Value numeratorVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(numerator));
            Value multiplierVal =
                rewriter
                    .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
                                            numeratorVal, countI)
                    .getResult();
            Value shiftVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(shift));

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
                        shiftVal, rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply output
            // zeropoint.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, quantizationInfo.output_zp());
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampHelper<arith::CmpIOp>(
                loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      ConvConverter,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}
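
// A minimal usage sketch (pass wiring assumed, not part of this file):
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
//   // then run the patterns via a dialect conversion or rewrite driver.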