MLIR 23.0.0git
SPIRVTosaOps.cpp
Go to the documentation of this file.
1//===- SPIRVTosaOps.cpp - MLIR SPIR-V Tosa operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Tosa operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/Support/InterleavedRange.h"
16#include <algorithm>
17
18namespace mlir::spirv {
19
20//===----------------------------------------------------------------------===//
21// SPIRV Tosa Custom formatters
22//===----------------------------------------------------------------------===//
23
27 auto f = [&]() {
28 int32_t value;
29 ParseResult r = parser.parseInteger(value);
30 elements.push_back(value);
31 return r;
32 };
33 if (parser.parseCommaSeparatedList(
35 "parsing values in integer list attribute")) {
36 return failure();
37 }
38
39 auto i32Type = IntegerType::get(parser.getContext(), 32);
40 auto type = TensorArmType::get(
41 ArrayRef{static_cast<int64_t>(elements.size())}, i32Type);
42 attr = DenseIntElementsAttr::get(type, elements);
43 return success();
44}
45
48 printer << llvm::interleaved_array(
49 llvm::map_range(attr.getValues<APInt>(),
50 [](const APInt &a) { return a.getSExtValue(); }));
51}
52
53//===----------------------------------------------------------------------===//
54// SPIRV Tosa Custom verifiers
55//===----------------------------------------------------------------------===//
56
57namespace {
58
59int64_t getIntValue(DenseIntElementsAttr attr, size_t idx) {
60 return attr.getValues<APInt>()[idx].getSExtValue();
61}
62
63LogicalResult verifyPool2DOutputDim(Operation *op, int64_t inputSize,
64 int64_t outputSize, int64_t kernelSize,
65 int64_t strideSize, int64_t padBefore,
66 int64_t padAfter, StringRef dimName,
67 StringRef dimAxis, StringRef padBeforeName,
68 StringRef padAfterName) {
69 if (ShapedType::isDynamic(inputSize))
70 return success();
71
72 const int64_t numerator = inputSize + padBefore + padAfter - kernelSize;
73 if (numerator % strideSize != 0)
74 return op->emitOpError("expected input_")
75 << dimName << " + pad_" << padBeforeName << " + pad_" << padAfterName
76 << " - kernel_" << dimAxis << " to be wholly divisible by stride_"
77 << dimAxis << ", got (" << inputSize << " + " << padBefore << " + "
78 << padAfter << " - " << kernelSize << ") / " << strideSize;
79
80 const int64_t calculatedOutput = numerator / strideSize + 1;
81 if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
82 return op->emitOpError("failed to verify that shapes of input and output "
83 "must satisfy [N,IH,IW,C] and [N,OH,OW,C], with "
84 "OH = ((IH + pad_top + pad_bottom - kernel_y) / "
85 "stride_y) + 1 and OW = ((IW + pad_left + "
86 "pad_right - kernel_x) / stride_x) + 1");
87
88 return success();
89}
90
91LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel,
92 DenseIntElementsAttr stride,
93 DenseIntElementsAttr pad, TensorArmType inputType,
94 TensorArmType outputType) {
95
96 if (!inputType.hasRank() || !outputType.hasRank())
97 return success();
98
99 if (failed(verifyPool2DOutputDim(
100 op, inputType.getDimSize(1), outputType.getDimSize(1),
101 getIntValue(kernel, 0), getIntValue(stride, 0), getIntValue(pad, 0),
102 getIntValue(pad, 1), "height", "y", "top", "bottom")))
103 return failure();
104
105 if (failed(verifyPool2DOutputDim(
106 op, inputType.getDimSize(2), outputType.getDimSize(2),
107 getIntValue(kernel, 1), getIntValue(stride, 1), getIntValue(pad, 2),
108 getIntValue(pad, 3), "width", "x", "left", "right")))
109 return failure();
110
111 return success();
112}
113
114LogicalResult verifyConvolutionOutputDim(int64_t inputSize, int64_t kernelSize,
115 int64_t outputSize, int64_t padBefore,
116 int64_t padAfter, int64_t strideSize,
117 int64_t dilationSize) {
118 if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
119 return success();
120
121 const int64_t numerator =
122 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilationSize;
123 if (numerator % strideSize != 0)
124 return failure();
125
126 const int64_t calculatedOutput = numerator / strideSize + 1;
127 if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
128 return failure();
129
130 return success();
131}
132
133LogicalResult
134verifyTransposeConvolutionOutputDim(int64_t inputSize, int64_t kernelSize,
135 int64_t outputSize, int64_t padBefore,
136 int64_t padAfter, int64_t strideSize) {
137 if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
138 return success();
139
140 const int64_t calculatedOutput =
141 (inputSize - 1) * strideSize + padBefore + padAfter + kernelSize;
142 if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput)
143 return failure();
144
145 return success();
146}
147
148LogicalResult verifyConv2DOutputShape(Operation *op, DenseIntElementsAttr pad,
149 DenseIntElementsAttr stride,
150 DenseIntElementsAttr dilation,
151 TensorArmType inputType,
152 TensorArmType weightType,
153 TensorArmType outputType) {
154 constexpr StringLiteral errorMessage =
155 "failed to verify that shapes of input, weight, and output must satisfy "
156 "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
157 "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
158 "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
159 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
160 return success();
161
162 if (failed(verifyConvolutionOutputDim(
163 inputType.getDimSize(1), weightType.getDimSize(1),
164 outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
165 getIntValue(stride, 0), getIntValue(dilation, 0))))
166 return op->emitOpError(errorMessage);
167
168 if (failed(verifyConvolutionOutputDim(
169 inputType.getDimSize(2), weightType.getDimSize(2),
170 outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
171 getIntValue(stride, 1), getIntValue(dilation, 1))))
172 return op->emitOpError(errorMessage);
173
174 return success();
175}
176
177LogicalResult verifyConv3DOutputShape(Operation *op, DenseIntElementsAttr pad,
178 DenseIntElementsAttr stride,
179 DenseIntElementsAttr dilation,
180 TensorArmType inputType,
181 TensorArmType weightType,
182 TensorArmType outputType) {
183 constexpr StringLiteral errorMessage =
184 "failed to verify that shapes of input, weight, and output must satisfy "
185 "[N,ID,IH,IW,*], [*,KD,KH,KW,*], [N,OD,OH,OW,*], with OD = ((ID - 1 + "
186 "pad_front + pad_back - (KD - 1) * dilation_d) / stride_d) + 1, OH = "
187 "((IH - 1 + pad_top + pad_bottom - (KH - 1) * dilation_y) / stride_y) "
188 "+ 1 and OW = ((IW - 1 + pad_left + pad_right - (KW - 1) * dilation_x) "
189 "/ stride_x) + 1";
190 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
191 return success();
192
193 if (failed(verifyConvolutionOutputDim(
194 inputType.getDimSize(1), weightType.getDimSize(1),
195 outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
196 getIntValue(stride, 0), getIntValue(dilation, 0))))
197 return op->emitOpError(errorMessage);
198
199 if (failed(verifyConvolutionOutputDim(
200 inputType.getDimSize(2), weightType.getDimSize(2),
201 outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
202 getIntValue(stride, 1), getIntValue(dilation, 1))))
203 return op->emitOpError(errorMessage);
204
205 if (failed(verifyConvolutionOutputDim(
206 inputType.getDimSize(3), weightType.getDimSize(3),
207 outputType.getDimSize(3), getIntValue(pad, 4), getIntValue(pad, 5),
208 getIntValue(stride, 2), getIntValue(dilation, 2))))
209 return op->emitOpError(errorMessage);
210
211 return success();
212}
213
214LogicalResult verifyDepthwiseConv2DOutputShape(
215 Operation *op, DenseIntElementsAttr pad, DenseIntElementsAttr stride,
216 DenseIntElementsAttr dilation, TensorArmType inputType,
217 TensorArmType weightType, TensorArmType outputType) {
218 constexpr StringLiteral errorMessage =
219 "failed to verify that shapes of input, weight, and output must satisfy "
220 "[N,IH,IW,*], [KH,KW,*,*], [N,OH,OW,*], with OH = ((IH - 1 + pad_top + "
221 "pad_bottom - (KH - 1) * dilation_y) / stride_y) + 1 and OW = ((IW - 1 "
222 "+ pad_left + pad_right - (KW - 1) * dilation_x) / stride_x) + 1";
223 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
224 return success();
225
226 if (failed(verifyConvolutionOutputDim(
227 inputType.getDimSize(1), weightType.getDimSize(0),
228 outputType.getDimSize(1), getIntValue(pad, 0), getIntValue(pad, 1),
229 getIntValue(stride, 0), getIntValue(dilation, 0))))
230 return op->emitOpError(errorMessage);
231
232 if (failed(verifyConvolutionOutputDim(
233 inputType.getDimSize(2), weightType.getDimSize(1),
234 outputType.getDimSize(2), getIntValue(pad, 2), getIntValue(pad, 3),
235 getIntValue(stride, 1), getIntValue(dilation, 1))))
236 return op->emitOpError(errorMessage);
237
238 return success();
239}
240
241LogicalResult verifyTransposeConv2DOutputShape(Operation *op,
242 DenseIntElementsAttr outPad,
243 DenseIntElementsAttr stride,
244 TensorArmType inputType,
245 TensorArmType weightType,
246 TensorArmType outputType) {
247 constexpr StringLiteral errorMessage =
248 "failed to verify that shapes of input, weight, and output must satisfy "
249 "[N,IH,IW,*], [*,KH,KW,*], [N,OH,OW,*], with OH = (IH - 1) * stride_y + "
250 "out_pad_top + out_pad_bottom + KH and OW = (IW - 1) * stride_x + "
251 "out_pad_left + out_pad_right + KW";
252 if (!inputType.hasRank() || !weightType.hasRank() || !outputType.hasRank())
253 return success();
254
255 const int64_t kernelHeight = weightType.getDimSize(1);
256 if (ShapedType::isStatic(kernelHeight) &&
257 (getIntValue(outPad, 0) <= -kernelHeight ||
258 getIntValue(outPad, 1) <= -kernelHeight))
259 return op->emitOpError("expected out_pad_top and out_pad_bottom to be > "
260 "-KH");
261
262 const int64_t kernelWidth = weightType.getDimSize(2);
263 if (ShapedType::isStatic(kernelWidth) &&
264 (getIntValue(outPad, 2) <= -kernelWidth ||
265 getIntValue(outPad, 3) <= -kernelWidth))
266 return op->emitOpError("expected out_pad_left and out_pad_right to be > "
267 "-KW");
268
269 if (failed(verifyTransposeConvolutionOutputDim(
270 inputType.getDimSize(1), kernelHeight, outputType.getDimSize(1),
271 getIntValue(outPad, 0), getIntValue(outPad, 1),
272 getIntValue(stride, 0))))
273 return op->emitOpError(errorMessage);
274
275 if (failed(verifyTransposeConvolutionOutputDim(
276 inputType.getDimSize(2), kernelWidth, outputType.getDimSize(2),
277 getIntValue(outPad, 2), getIntValue(outPad, 3),
278 getIntValue(stride, 1))))
279 return op->emitOpError(errorMessage);
280
281 return success();
282}
283
284LogicalResult verifyConcatOutputShape(Operation *op, TypeRange inputTypes,
285 TensorArmType outputType, int32_t axis) {
286 constexpr StringLiteral errorMessage =
287 "failed to verify that shape of output must match the concatenation of "
288 "input1 along axis";
289 if (!outputType.hasRank())
290 return success();
291
292 if (llvm::any_of(inputTypes, [](Type type) {
293 return !cast<TensorArmType>(type).hasRank();
294 }))
295 return success();
296
297 for (int64_t dim = 0, rank = outputType.getRank(); dim < rank; ++dim) {
298 int64_t outputDim = outputType.getDimSize(dim);
299 if (ShapedType::isDynamic(outputDim))
300 continue;
301
302 if (dim != axis) {
303 for (Type type : inputTypes) {
304 int64_t inputDim = cast<TensorArmType>(type).getDimSize(dim);
305 if (ShapedType::isStatic(inputDim) && inputDim != outputDim)
306 return op->emitOpError(errorMessage);
307 }
308 continue;
309 }
310
311 int64_t concatDim = 0;
312 for (Type type : inputTypes) {
313 int64_t inputDim = cast<TensorArmType>(type).getDimSize(dim);
314 if (ShapedType::isDynamic(inputDim)) {
315 concatDim = ShapedType::kDynamic;
316 break;
317 }
318 concatDim += inputDim;
319 }
320
321 if (ShapedType::isStatic(concatDim) && concatDim != outputDim)
322 return op->emitOpError(errorMessage);
323 }
324
325 return success();
326}
327
328} // namespace
329
330LogicalResult TosaAvgPool2DOp::verify() {
331 return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
332 getInputType(), getResultType());
333}
334
335LogicalResult TosaConv2DOp::verify() {
336 return verifyConv2DOutputShape(getOperation(), getPad(), getStride(),
337 getDilation(), getInputType(), getWeightType(),
338 getResultType());
339}
340
341LogicalResult TosaConv3DOp::verify() {
342 return verifyConv3DOutputShape(getOperation(), getPad(), getStride(),
343 getDilation(), getInputType(), getWeightType(),
344 getResultType());
345}
346
347LogicalResult TosaDepthwiseConv2DOp::verify() {
348 return verifyDepthwiseConv2DOutputShape(getOperation(), getPad(), getStride(),
349 getDilation(), getInputType(),
350 getWeightType(), getResultType());
351}
352
353LogicalResult TosaMaxPool2DOp::verify() {
354 return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(),
355 getInputType(), getResultType());
356}
357
358LogicalResult TosaTransposeConv2DOp::verify() {
359 return verifyTransposeConv2DOutputShape(getOperation(), getOutPad(),
360 getStride(), getInputType(),
361 getWeightType(), getResultType());
362}
363
364LogicalResult TosaConcatOp::verify() {
365 return verifyConcatOutputShape(getOperation(), getInput1Types(),
366 getResultType(), getAxis());
367}
368
369LogicalResult TosaSelectOp::verify() {
370 TensorArmType condType = getConditionType();
371 TensorArmType trueValType = getTrueValueType();
372 TensorArmType falseValType = getFalseValueType();
373 TensorArmType resultType = getResultType();
374
375 if (llvm::any_of(ArrayRef<TensorArmType>{condType, trueValType, falseValType,
376 resultType},
377 [](TensorArmType type) { return !type.hasRank(); }))
378 return success();
379
380 ArrayRef<int64_t> condShape = condType.getShape();
381 ArrayRef<int64_t> trueValShape = trueValType.getShape();
382 ArrayRef<int64_t> falseValShape = falseValType.getShape();
383 ArrayRef<int64_t> resultShape = resultType.getShape();
384
385 if (!llvm::all_equal({condShape.size(), trueValShape.size(),
386 falseValShape.size(), resultShape.size()})) {
387 // The AllRanksMatch predicate enforces that all ranks are equal.
388 // This is just an extra safe guard for the code coming after that
389 // assumes that all ranks are equal.
390 return failure();
391 }
392
393 for (auto dims :
394 llvm::zip_equal(condShape, trueValShape, falseValShape, resultShape)) {
395 auto [condDim, trueValDim, falseValDim, resultDim] = dims;
396
397 if (llvm::any_of(
398 ArrayRef<int64_t>{condDim, trueValDim, falseValDim, resultDim},
399 [](int64_t dim) { return ShapedType::isDynamic(dim); })) {
400 continue;
401 }
402
403 auto isPairBroadcastable = [](int64_t lhs, int64_t rhs) {
404 return lhs == rhs || lhs == 1 || rhs == 1;
405 };
406
407 if (!isPairBroadcastable(condDim, trueValDim) ||
408 !isPairBroadcastable(condDim, falseValDim) ||
409 !isPairBroadcastable(trueValDim, falseValDim)) {
410 return emitOpError(
411 "failed to verify that the shape of inputs: condition, "
412 "true_value, and false_value are compatible for "
413 "broadcasting");
414 }
415
416 int64_t bradcastedInputDim =
417 std::max(condDim, std::max(trueValDim, falseValDim));
418 if (bradcastedInputDim != resultDim) {
419 return emitOpError(
420 "failed to verify that the broadcast shape of inputs: condition, "
421 "true_value, and false_value is equal to "
422 "the output shape");
423 }
424 }
425 return success();
426}
427
428} // namespace mlir::spirv
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
lhs
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
SPIR-V TensorARM Type.
Definition SPIRVTypes.h:509
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *, DenseIntElementsAttr attr)
ParseResult parseSPIRV_I32_1DArmTensor(OpAsmParser &parser, DenseIntElementsAttr &attr)